Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
09e6e71c
Commit
09e6e71c
authored
Mar 02, 2022
by
Zihan Wang
Browse files
lint
parent
32867f40
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
755 additions
and
631 deletions
+755
-631
official/projects/longformer/longformer_attention.py
official/projects/longformer/longformer_attention.py
+260
-171
official/projects/longformer/longformer_attention_test.py
official/projects/longformer/longformer_attention_test.py
+98
-60
official/projects/longformer/longformer_encoder.py
official/projects/longformer/longformer_encoder.py
+141
-151
official/projects/longformer/longformer_encoder_block.py
official/projects/longformer/longformer_encoder_block.py
+122
-136
official/projects/longformer/longformer_encoder_test.py
official/projects/longformer/longformer_encoder_test.py
+51
-36
official/projects/longformer/longformer_experiments.py
official/projects/longformer/longformer_experiments.py
+71
-65
official/projects/longformer/train.py
official/projects/longformer/train.py
+12
-12
No files found.
official/projects/longformer/longformer_attention.py
View file @
09e6e71c
This diff is collapsed.
Click to expand it.
official/projects/longformer/longformer_attention_test.py
View file @
09e6e71c
...
@@ -12,25 +12,26 @@
...
@@ -12,25 +12,26 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Tests for
the
attention
layer
."""
"""Tests for
official.nlp.projects.longformer.longformer_
attention."""
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.projects.longformer
import
longformer_attention
from
official.projects.longformer
import
longformer_attention
from
official.modeling.tf_utils
import
get_shape_list
from
official.modeling.tf_utils
import
get_shape_list
def
_create_mock_attention_data
(
def
_create_mock_attention_data
(
num_heads
,
num_heads
,
key_dim
,
key_dim
,
value_dim
,
value_dim
,
q_seq_length
,
q_seq_length
,
kv_seq_length
,
kv_seq_length
,
batch_size
,
batch_size
,
include_mask
=
False
):
include_mask
=
False
):
"""Creates mock testing data.
"""Creates mock testing data.
Args:
Args:
...
@@ -48,15 +49,15 @@ def _create_mock_attention_data(
...
@@ -48,15 +49,15 @@ def _create_mock_attention_data(
value_shape
=
(
batch_size
,
kv_seq_length
,
value_dim
)
value_shape
=
(
batch_size
,
kv_seq_length
,
value_dim
)
data
=
dict
(
data
=
dict
(
query
=
tf
.
random
.
normal
(
shape
=
query_shape
),
query
=
tf
.
random
.
normal
(
shape
=
query_shape
),
value
=
tf
.
random
.
normal
(
shape
=
value_shape
),
value
=
tf
.
random
.
normal
(
shape
=
value_shape
),
key
=
tf
.
random
.
normal
(
shape
=
value_shape
))
key
=
tf
.
random
.
normal
(
shape
=
value_shape
))
total_seq_length
=
kv_seq_length
total_seq_length
=
kv_seq_length
if
include_mask
:
if
include_mask
:
mask_shape
=
(
batch_size
,
num_heads
,
q_seq_length
,
total_seq_length
)
mask_shape
=
(
batch_size
,
num_heads
,
q_seq_length
,
total_seq_length
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
"
float32
"
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
'
float32
'
)
mask_data
=
dict
(
attention_mask
=
mask_data
)
mask_data
=
dict
(
attention_mask
=
mask_data
)
data
.
update
(
mask_data
)
data
.
update
(
mask_data
)
...
@@ -65,6 +66,12 @@ def _create_mock_attention_data(
...
@@ -65,6 +66,12 @@ def _create_mock_attention_data(
@
keras_parameterized
.
run_all_keras_modes
@
keras_parameterized
.
run_all_keras_modes
class
LongformerAttentionTest
(
keras_parameterized
.
TestCase
):
class
LongformerAttentionTest
(
keras_parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
LongformerAttentionTest
,
self
).
setUp
()
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
def
_get_hidden_states
(
self
):
def
_get_hidden_states
(
self
):
return
tf
.
convert_to_tensor
(
return
tf
.
convert_to_tensor
(
[
[
...
@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
...
@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
def
test_diagonalize
(
self
):
def
test_diagonalize
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
1
,
8
,
4
))
# set seq length = 8, hidden dim = 4
hidden_states
=
tf
.
reshape
(
hidden_states
,
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
(
1
,
8
,
4
))
# set seq length = 8, hidden dim = 4
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
window_overlap_size
=
get_shape_list
(
chunked_hidden_states
)[
2
]
window_overlap_size
=
get_shape_list
(
chunked_hidden_states
)[
2
]
self
.
assertTrue
(
window_overlap_size
==
4
)
self
.
assertTrue
(
window_overlap_size
==
4
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_diagonalize
(
chunked_hidden_states
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_diagonalize
(
chunked_hidden_states
)
self
.
assertTrue
(
self
.
assertTrue
(
get_shape_list
(
padded_hidden_states
)[
-
1
]
==
get_shape_list
(
chunked_hidden_states
)[
-
1
]
+
window_overlap_size
-
1
get_shape_list
(
padded_hidden_states
)[
-
1
]
==
get_shape_list
(
chunked_hidden_states
)[
-
1
]
+
window_overlap_size
-
1
)
)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
:
4
],
chunked_hidden_states
[
0
,
0
,
0
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
:
4
],
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
4
:],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
chunked_hidden_states
[
0
,
0
,
0
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
4
:],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
3
:],
chunked_hidden_states
[
0
,
0
,
-
1
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
3
:],
chunked_hidden_states
[
0
,
0
,
-
1
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
:
3
],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
padded_hidden_states
[
0
,
0
,
-
1
,
:
3
],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
)
def
test_pad_and_transpose_last_two_dims
(
self
):
def
test_pad_and_transpose_last_two_dims
(
self
):
...
@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
...
@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
self
.
assertTrue
(
get_shape_list
(
hidden_states
),
[
1
,
8
,
4
])
self
.
assertTrue
(
get_shape_list
(
hidden_states
),
[
1
,
8
,
4
])
# pad along seq length dim
# pad along seq length dim
paddings
=
tf
.
constant
([[
0
,
0
],
[
0
,
0
],
[
0
,
1
],
[
0
,
0
]],
dtype
=
tf
.
dtypes
.
int32
)
paddings
=
tf
.
constant
([[
0
,
0
],
[
0
,
0
],
[
0
,
1
],
[
0
,
0
]],
dtype
=
tf
.
dtypes
.
int32
)
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_transpose_last_two_dims
(
hidden_states
,
paddings
)
hidden_states
,
window_overlap
=
2
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_transpose_last_two_dims
(
hidden_states
,
paddings
)
self
.
assertTrue
(
get_shape_list
(
padded_hidden_states
)
==
[
1
,
1
,
8
,
5
])
self
.
assertTrue
(
get_shape_list
(
padded_hidden_states
)
==
[
1
,
1
,
8
,
5
])
expected_added_dim
=
tf
.
zeros
((
5
,),
dtype
=
tf
.
dtypes
.
float32
)
expected_added_dim
=
tf
.
zeros
((
5
,),
dtype
=
tf
.
dtypes
.
float32
)
tf
.
debugging
.
assert_near
(
expected_added_dim
,
padded_hidden_states
[
0
,
0
,
-
1
,
:],
rtol
=
1e-6
)
tf
.
debugging
.
assert_near
(
expected_added_dim
,
padded_hidden_states
[
0
,
0
,
-
1
,
:],
rtol
=
1e-6
)
tf
.
debugging
.
assert_near
(
tf
.
debugging
.
assert_near
(
hidden_states
[
0
,
0
,
-
1
,
:],
tf
.
reshape
(
padded_hidden_states
,
(
1
,
-
1
))[
0
,
24
:
32
],
rtol
=
1e-6
hidden_states
[
0
,
0
,
-
1
,
:],
tf
.
reshape
(
padded_hidden_states
,
(
1
,
-
1
))[
0
,
24
:
32
],
rtol
=
1e-6
)
)
def
test_mask_invalid_locations
(
self
):
def
test_mask_invalid_locations
(
self
):
...
@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
...
@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
batch_size
=
1
batch_size
=
1
seq_length
=
8
seq_length
=
8
hidden_size
=
4
hidden_size
=
4
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
tf
.
reshape
(
hidden_states
,
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hid_states_1
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
1
)
hidden_states
,
window_overlap
=
2
)
hid_states_2
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
2
)
hid_states_3
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
:,
:
3
],
2
)
hid_states_1
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hid_states_4
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
2
:,
:],
2
)
hidden_states
,
1
)
hid_states_2
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_1
),
tf
.
dtypes
.
int32
))
==
8
)
hidden_states
,
2
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_2
),
tf
.
dtypes
.
int32
))
==
24
)
hid_states_3
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_3
),
tf
.
dtypes
.
int32
))
==
24
)
hidden_states
[:,
:,
:,
:
3
],
2
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_4
),
tf
.
dtypes
.
int32
))
==
12
)
hid_states_4
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
2
:,
:],
2
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_1
),
tf
.
dtypes
.
int32
))
==
8
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_2
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_3
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_4
),
tf
.
dtypes
.
int32
))
==
12
)
def
test_chunk
(
self
):
def
test_chunk
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
self
.
_get_hidden_states
()
batch_size
=
1
batch_size
=
1
seq_length
=
8
seq_length
=
8
hidden_size
=
4
hidden_size
=
4
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
# expected slices across chunk and seq length dim
# expected slices across chunk and seq length dim
expected_slice_along_seq_length
=
tf
.
convert_to_tensor
([
0.4983
,
-
0.7584
,
-
1.6944
],
dtype
=
tf
.
dtypes
.
float32
)
expected_slice_along_seq_length
=
tf
.
convert_to_tensor
(
expected_slice_along_chunk
=
tf
.
convert_to_tensor
([
0.4983
,
-
1.8348
,
-
0.7584
,
2.0514
],
dtype
=
tf
.
dtypes
.
float32
)
[
0.4983
,
-
0.7584
,
-
1.6944
],
dtype
=
tf
.
dtypes
.
float32
)
expected_slice_along_chunk
=
tf
.
convert_to_tensor
(
[
0.4983
,
-
1.8348
,
-
0.7584
,
2.0514
],
dtype
=
tf
.
dtypes
.
float32
)
self
.
assertTrue
(
get_shape_list
(
chunked_hidden_states
)
==
[
1
,
3
,
4
,
4
])
self
.
assertTrue
(
get_shape_list
(
chunked_hidden_states
)
==
[
1
,
3
,
4
,
4
])
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
:,
0
,
0
],
expected_slice_along_seq_length
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
:,
0
,
0
],
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
0
,
:,
0
],
expected_slice_along_chunk
,
rtol
=
1e-3
)
expected_slice_along_seq_length
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
0
,
:,
0
],
expected_slice_along_chunk
,
rtol
=
1e-3
)
def
test_layer_local_attn
(
self
):
def
test_layer_local_attn
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
self
.
_get_hidden_states
()
batch_size
,
seq_length
,
hidden_size
=
hidden_states
.
shape
batch_size
,
seq_length
,
_
=
hidden_states
.
shape
layer
=
longformer_attention
.
LongformerAttention
(
layer
=
longformer_attention
.
LongformerAttention
(
num_heads
=
2
,
num_heads
=
2
,
key_dim
=
4
,
key_dim
=
4
,
...
@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
...
@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
attention_mask
=
tf
.
zeros
((
batch_size
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask
=
tf
.
zeros
((
batch_size
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
,
1
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
,
1
)
is_global_attn
=
tf
.
math
.
reduce_any
(
is_index_global_attn
)
attention_mask
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
1
,
-
10000.0
,
attention_mask
[:,
:,
None
,
None
])
attention_mask
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
1
,
-
10000.0
,
attention_mask
[:,
:,
None
,
None
])
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
output_hidden_states
=
layer
(
output_hidden_states
=
layer
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
)[
0
]
)[
0
]
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
1
,
4
,
8
))
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
1
,
4
,
8
))
...
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
...
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
)
)
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
tf
.
concat
([
self
.
_get_hidden_states
(),
self
.
_get_hidden_states
()
-
0.5
],
axis
=
0
)
hidden_states
=
tf
.
concat
(
[
self
.
_get_hidden_states
(),
self
.
_get_hidden_states
()
-
0.5
],
axis
=
0
)
batch_size
,
seq_length
,
hidden_size
=
hidden_states
.
shape
batch_size
,
seq_length
,
hidden_size
=
hidden_states
.
shape
# create attn mask
# create attn mask
attention_mask_1
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_1
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_2
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_2
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_1
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
2
,
-
10000.0
,
attention_mask_1
)
attention_mask_1
)
attention_mask_2
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_2
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
2
,
-
10000.0
,
attention_mask_1
)
attention_mask_2
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_2
)
attention_mask
=
tf
.
concat
([
attention_mask_1
,
attention_mask_2
],
axis
=
0
)
attention_mask
=
tf
.
concat
([
attention_mask_1
,
attention_mask_2
],
axis
=
0
)
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_global_attn
=
tf
.
math
.
reduce_any
(
is_index_global_attn
)
output_hidden_states
=
layer
(
output_hidden_states
=
layer
(
hidden_states
=
hidden_states
,
attention_mask
=-
tf
.
math
.
abs
(
attention_mask
),
hidden_states
=
hidden_states
,
attention_mask
=-
tf
.
math
.
abs
(
attention_mask
),
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
,
is_index_masked
=
is_index_masked
,
)[
0
]
is_index_global_attn
=
is_index_global_attn
,
)[
0
]
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
2
,
4
,
8
))
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
2
,
4
,
8
))
if
__name__
==
"__main__"
:
if
__name__
==
'__main__'
:
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
tf
.
test
.
main
()
tf
.
test
.
main
()
official/projects/longformer/longformer_encoder.py
View file @
09e6e71c
...
@@ -23,29 +23,16 @@ from absl import logging
...
@@ -23,29 +23,16 @@ from absl import logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
layers
from
official.projects.longformer.longformer_encoder_block
import
LongformerEncoderBlock
from
official.projects.longformer.longformer_encoder_block
import
\
LongformerEncoderBlock
from
official.modeling.tf_utils
import
get_shape_list
from
official.modeling.tf_utils
import
get_shape_list
_Initializer
=
Union
[
str
,
tf
.
keras
.
initializers
.
Initializer
]
_Initializer
=
Union
[
str
,
tf
.
keras
.
initializers
.
Initializer
]
_approx_gelu
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
)
_approx_gelu
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
)
# Transferred from huggingface.longformer.TFLongformerMainLayer & TFLongformerEncoder
class
LongformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
class
LongformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
"""Bi-directional Transformer-based encoder network.
"""LongformerEncoder
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
Args:
Args:
vocab_size: The size of the token vocabulary.
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
attention_window: list of ints representing the window size for each layer.
...
@@ -85,27 +72,27 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -85,27 +72,27 @@ class LongformerEncoder(tf.keras.layers.Layer):
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
vocab_size
:
int
,
vocab_size
:
int
,
attention_window
:
Union
[
List
[
int
],
int
]
=
512
,
attention_window
:
Union
[
List
[
int
],
int
]
=
512
,
global_attention_size
:
int
=
0
,
global_attention_size
:
int
=
0
,
pad_token_id
:
int
=
1
,
pad_token_id
:
int
=
1
,
hidden_size
:
int
=
768
,
hidden_size
:
int
=
768
,
num_layers
:
int
=
12
,
num_layers
:
int
=
12
,
num_attention_heads
:
int
=
12
,
num_attention_heads
:
int
=
12
,
max_sequence_length
:
int
=
512
,
max_sequence_length
:
int
=
512
,
type_vocab_size
:
int
=
16
,
type_vocab_size
:
int
=
16
,
inner_dim
:
int
=
3072
,
inner_dim
:
int
=
3072
,
inner_activation
:
Callable
[...,
Any
]
=
_approx_gelu
,
inner_activation
:
Callable
[...,
Any
]
=
_approx_gelu
,
output_dropout
:
float
=
0.1
,
output_dropout
:
float
=
0.1
,
attention_dropout
:
float
=
0.1
,
attention_dropout
:
float
=
0.1
,
initializer
:
_Initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
initializer
:
_Initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
stddev
=
0.02
),
output_range
:
Optional
[
int
]
=
None
,
output_range
:
Optional
[
int
]
=
None
,
embedding_width
:
Optional
[
int
]
=
None
,
embedding_width
:
Optional
[
int
]
=
None
,
embedding_layer
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
embedding_layer
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
norm_first
:
bool
=
False
,
norm_first
:
bool
=
False
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
# Longformer args
# Longformer args
self
.
_attention_window
=
attention_window
self
.
_attention_window
=
attention_window
...
@@ -120,93 +107,91 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -120,93 +107,91 @@ class LongformerEncoder(tf.keras.layers.Layer):
if
embedding_layer
is
None
:
if
embedding_layer
is
None
:
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
initializer
=
initializer
,
name
=
'word_embeddings'
)
name
=
'word_embeddings'
)
else
:
else
:
self
.
_embedding_layer
=
embedding_layer
self
.
_embedding_layer
=
embedding_layer
self
.
_position_embedding_layer
=
layers
.
PositionEmbedding
(
self
.
_position_embedding_layer
=
layers
.
PositionEmbedding
(
initializer
=
initializer
,
initializer
=
initializer
,
max_length
=
max_sequence_length
,
max_length
=
max_sequence_length
,
name
=
'position_embedding'
)
name
=
'position_embedding'
)
self
.
_type_embedding_layer
=
layers
.
OnDeviceEmbedding
(
self
.
_type_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
type_vocab_size
,
vocab_size
=
type_vocab_size
,
embedding_width
=
embedding_width
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
initializer
=
initializer
,
use_one_hot
=
True
,
use_one_hot
=
True
,
name
=
'type_embeddings'
)
name
=
'type_embeddings'
)
self
.
_embedding_norm_layer
=
tf
.
keras
.
layers
.
LayerNormalization
(
self
.
_embedding_norm_layer
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
'embeddings/layer_norm'
,
axis
=-
1
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
)
name
=
'embeddings/layer_norm'
,
axis
=-
1
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
)
self
.
_embedding_dropout
=
tf
.
keras
.
layers
.
Dropout
(
self
.
_embedding_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
output_dropout
,
name
=
'embedding_dropout'
)
rate
=
output_dropout
,
name
=
'embedding_dropout'
)
# We project the 'embedding' output to 'hidden_size' if it is not already
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
# 'hidden_size'.
self
.
_embedding_projection
=
None
self
.
_embedding_projection
=
None
if
embedding_width
!=
hidden_size
:
if
embedding_width
!=
hidden_size
:
self
.
_embedding_projection
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_embedding_projection
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
'...x,xy->...y'
,
'...x,xy->...y'
,
output_shape
=
hidden_size
,
output_shape
=
hidden_size
,
bias_axes
=
'y'
,
bias_axes
=
'y'
,
kernel_initializer
=
initializer
,
kernel_initializer
=
initializer
,
name
=
'embedding_projection'
)
name
=
'embedding_projection'
)
self
.
_transformer_layers
=
[]
self
.
_transformer_layers
=
[]
self
.
_attention_mask_layer
=
layers
.
SelfAttentionMask
(
self
.
_attention_mask_layer
=
layers
.
SelfAttentionMask
(
name
=
'self_attention_mask'
)
name
=
'self_attention_mask'
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
layer
=
LongformerEncoderBlock
(
layer
=
LongformerEncoderBlock
(
global_attention_size
=
global_attention_size
,
global_attention_size
=
global_attention_size
,
num_attention_heads
=
num_attention_heads
,
num_attention_heads
=
num_attention_heads
,
inner_dim
=
inner_dim
,
inner_dim
=
inner_dim
,
inner_activation
=
inner_activation
,
inner_activation
=
inner_activation
,
# Longformer, instead of passing a list of attention_window, pass a value to sub-block
attention_window
=
attention_window
[
i
],
attention_window
=
attention_window
if
isinstance
(
attention_window
,
int
)
else
attention_window
[
i
],
layer_id
=
i
,
layer_id
=
i
,
output_dropout
=
output_dropout
,
output_dropout
=
output_dropout
,
attention_dropout
=
attention_dropout
,
attention_dropout
=
attention_dropout
,
norm_first
=
norm_first
,
norm_first
=
norm_first
,
output_range
=
output_range
if
i
==
num_layers
-
1
else
None
,
output_range
=
output_range
if
i
==
num_layers
-
1
else
None
,
kernel_initializer
=
initializer
,
kernel_initializer
=
initializer
,
name
=
f
'transformer/layer_
{
i
}
'
)
name
=
'transformer/layer_%d'
%
i
)
self
.
_transformer_layers
.
append
(
layer
)
self
.
_transformer_layers
.
append
(
layer
)
self
.
_pooler_layer
=
tf
.
keras
.
layers
.
Dense
(
self
.
_pooler_layer
=
tf
.
keras
.
layers
.
Dense
(
units
=
hidden_size
,
units
=
hidden_size
,
activation
=
'tanh'
,
activation
=
'tanh'
,
kernel_initializer
=
initializer
,
kernel_initializer
=
initializer
,
name
=
'pooler_transform'
)
name
=
'pooler_transform'
)
self
.
_config
=
{
self
.
_config
=
{
'vocab_size'
:
vocab_size
,
'vocab_size'
:
vocab_size
,
'hidden_size'
:
hidden_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
num_layers
,
'num_layers'
:
num_layers
,
'num_attention_heads'
:
num_attention_heads
,
'num_attention_heads'
:
num_attention_heads
,
'max_sequence_length'
:
max_sequence_length
,
'max_sequence_length'
:
max_sequence_length
,
'type_vocab_size'
:
type_vocab_size
,
'type_vocab_size'
:
type_vocab_size
,
'inner_dim'
:
inner_dim
,
'inner_dim'
:
inner_dim
,
'inner_activation'
:
tf
.
keras
.
activations
.
serialize
(
activation
),
'inner_activation'
:
tf
.
keras
.
activations
.
serialize
(
activation
),
'output_dropout'
:
output_dropout
,
'output_dropout'
:
output_dropout
,
'attention_dropout'
:
attention_dropout
,
'attention_dropout'
:
attention_dropout
,
'initializer'
:
tf
.
keras
.
initializers
.
serialize
(
initializer
),
'initializer'
:
tf
.
keras
.
initializers
.
serialize
(
initializer
),
'output_range'
:
output_range
,
'output_range'
:
output_range
,
'embedding_width'
:
embedding_width
,
'embedding_width'
:
embedding_width
,
'embedding_layer'
:
embedding_layer
,
'embedding_layer'
:
embedding_layer
,
'norm_first'
:
norm_first
,
'norm_first'
:
norm_first
,
# Longformer
'attention_window'
:
attention_window
,
'attention_window'
:
attention_window
,
'global_attention_size'
:
global_attention_size
,
'global_attention_size'
:
global_attention_size
,
'pad_token_id'
:
pad_token_id
,
'pad_token_id'
:
pad_token_id
,
}
}
self
.
inputs
=
dict
(
self
.
inputs
=
dict
(
input_word_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_word_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
))
input_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
))
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
word_embeddings
=
None
word_embeddings
=
None
...
@@ -214,22 +199,23 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -214,22 +199,23 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_ids
=
inputs
.
get
(
'input_word_ids'
)
# input_ids
word_ids
=
inputs
.
get
(
'input_word_ids'
)
# input_ids
mask
=
inputs
.
get
(
'input_mask'
)
# attention_mask
mask
=
inputs
.
get
(
'input_mask'
)
# attention_mask
type_ids
=
inputs
.
get
(
'input_type_ids'
)
# token_type_ids
type_ids
=
inputs
.
get
(
'input_type_ids'
)
# token_type_ids
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
# input_embeds
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
# input_embeds
else
:
else
:
raise
ValueError
(
'Unexpected inputs type to
%s.'
%
self
.
__class__
)
raise
ValueError
(
f
'Unexpected inputs type to
{
self
.
__class__
}
.'
)
(
(
padding_len
,
padding_len
,
word_ids
,
word_ids
,
mask
,
mask
,
type_ids
,
type_ids
,
word_embeddings
,
word_embeddings
,
)
=
self
.
_pad_to_window_size
(
)
=
self
.
_pad_to_window_size
(
word_ids
=
word_ids
,
word_ids
=
word_ids
,
mask
=
mask
,
mask
=
mask
,
type_ids
=
type_ids
,
type_ids
=
type_ids
,
word_embeddings
=
word_embeddings
,
word_embeddings
=
word_embeddings
,
pad_token_id
=
self
.
_pad_token_id
pad_token_id
=
self
.
_pad_token_id
)
)
if
word_embeddings
is
None
:
if
word_embeddings
is
None
:
...
@@ -247,46 +233,47 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -247,46 +233,47 @@ class LongformerEncoder(tf.keras.layers.Layer):
batch_size
,
seq_len
=
get_shape_list
(
mask
)
batch_size
,
seq_len
=
get_shape_list
(
mask
)
# create masks with fixed len global_attention_size
# create masks with fixed len global_attention_size
mask
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
int32
)
*
2
,
mask
=
tf
.
transpose
(
tf
.
concat
(
tf
.
transpose
(
mask
)[
self
.
_global_attention_size
:]],
axis
=
0
))
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
int32
)
*
2
,
tf
.
transpose
(
mask
)[
self
.
_global_attention_size
:]],
axis
=
0
))
is_index_masked
=
tf
.
math
.
less
(
mask
,
1
)
is_index_masked
=
tf
.
math
.
less
(
mask
,
1
)
is_index_global_attn
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
is_index_global_attn
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
),
tf
.
zeros
((
seq_len
-
self
.
_global_attention_size
,
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
),
batch_size
),
tf
.
bool
)
tf
.
zeros
((
seq_len
-
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
)
],
axis
=
0
))
],
axis
=
0
))
is_global_attn
=
self
.
_global_attention_size
>
0
# Longformer
# Longformer
attention_mask
=
mask
attention_mask
=
mask
extended_attention_mask
=
tf
.
reshape
(
extended_attention_mask
=
tf
.
reshape
(
attention_mask
,
(
tf
.
shape
(
mask
)[
0
],
tf
.
shape
(
mask
)[
1
],
1
,
1
)
attention_mask
,
(
tf
.
shape
(
mask
)[
0
],
tf
.
shape
(
mask
)[
1
],
1
,
1
)
)
)
attention_mask
=
tf
.
cast
(
tf
.
math
.
abs
(
1
-
extended_attention_mask
),
tf
.
dtypes
.
float32
)
*
-
10000.0
attention_mask
=
tf
.
cast
(
tf
.
math
.
abs
(
1
-
extended_attention_mask
),
tf
.
dtypes
.
float32
)
*
-
10000.0
encoder_outputs
=
[]
encoder_outputs
=
[]
x
=
embeddings
x
=
embeddings
# TFLongformerEncoder
# TFLongformerEncoder
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
)
:
for
layer
in
self
.
_transformer_layers
:
x
=
layer
([
x
=
layer
([
x
,
x
,
attention_mask
,
attention_mask
,
is_index_masked
,
is_index_masked
,
is_index_global_attn
,
is_index_global_attn
])
is_global_attn
])
encoder_outputs
.
append
(
x
)
encoder_outputs
.
append
(
x
)
last_encoder_output
=
encoder_outputs
[
-
1
]
last_encoder_output
=
encoder_outputs
[
-
1
]
if
padding_len
>
0
:
if
padding_len
>
0
:
last_encoder_output
=
last_encoder_output
[:,
:
-
padding_len
]
last_encoder_output
=
last_encoder_output
[:,
:
-
padding_len
]
first_token_tensor
=
last_encoder_output
[:,
0
,
:]
first_token_tensor
=
last_encoder_output
[:,
0
,
:]
pooled_output
=
self
.
_pooler_layer
(
first_token_tensor
)
pooled_output
=
self
.
_pooler_layer
(
first_token_tensor
)
return
dict
(
return
dict
(
sequence_output
=
last_encoder_output
,
sequence_output
=
last_encoder_output
,
pooled_output
=
pooled_output
,
pooled_output
=
pooled_output
,
encoder_outputs
=
encoder_outputs
)
encoder_outputs
=
encoder_outputs
)
def
get_embedding_table
(
self
):
def
get_embedding_table
(
self
):
return
self
.
_embedding_layer
.
embeddings
return
self
.
_embedding_layer
.
embeddings
...
@@ -311,36 +298,36 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -311,36 +298,36 @@ class LongformerEncoder(tf.keras.layers.Layer):
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
if
'embedding_layer'
in
config
and
config
[
'embedding_layer'
]
is
not
None
:
if
'embedding_layer'
in
config
and
config
[
'embedding_layer'
]
is
not
None
:
warn_string
=
(
warn_string
=
(
'You are reloading a model that was saved with a '
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.'
)
'To work around this, load the model outside of the Keras API.'
)
print
(
'WARNING: '
+
warn_string
)
print
(
'WARNING: '
+
warn_string
)
logging
.
warn
(
warn_string
)
logging
.
warn
(
warn_string
)
return
cls
(
**
config
)
return
cls
(
**
config
)
def
_pad_to_window_size
(
def
_pad_to_window_size
(
self
,
self
,
word_ids
,
word_ids
,
mask
,
mask
,
type_ids
,
type_ids
,
word_embeddings
,
word_embeddings
,
pad_token_id
,
pad_token_id
,
):
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
# padding
attention_window
=
(
attention_window
=
max
(
self
.
_attention_window
)
self
.
_attention_window
if
isinstance
(
self
.
_attention_window
,
int
)
else
max
(
self
.
_attention_window
)
)
assert
attention_window
%
2
==
0
,
f
"`attention_window` should be an even value. Given
{
attention_window
}
"
assert
attention_window
%
2
==
0
,
\
f
'`attention_window` should be an even value. Given
{
attention_window
}
'
input_shape
=
get_shape_list
(
word_ids
)
if
word_ids
is
not
None
else
get_shape_list
(
word_embeddings
)
input_shape
=
get_shape_list
(
word_ids
)
if
word_ids
is
not
None
else
get_shape_list
(
word_embeddings
)
batch_size
,
seq_len
=
input_shape
[:
2
]
batch_size
,
seq_len
=
input_shape
[:
2
]
if
seq_len
is
not
None
:
if
seq_len
is
not
None
:
padding_len
=
(
attention_window
-
seq_len
%
attention_window
)
%
attention_window
padding_len
=
(
attention_window
-
seq_len
%
attention_window
)
%
attention_window
else
:
else
:
padding_len
=
0
padding_len
=
0
...
@@ -355,14 +342,17 @@ class LongformerEncoder(tf.keras.layers.Layer):
...
@@ -355,14 +342,17 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_embeddings_padding
=
self
.
_embedding_layer
(
word_ids_padding
)
word_embeddings_padding
=
self
.
_embedding_layer
(
word_ids_padding
)
return
tf
.
concat
([
word_embeddings
,
word_embeddings_padding
],
axis
=-
2
)
return
tf
.
concat
([
word_embeddings
,
word_embeddings_padding
],
axis
=-
2
)
word_embeddings
=
tf
.
cond
(
tf
.
math
.
greater
(
padding_len
,
0
),
pad_embeddings
,
lambda
:
word_embeddings
)
word_embeddings
=
tf
.
cond
(
tf
.
math
.
greater
(
padding_len
,
0
),
pad_embeddings
,
lambda
:
word_embeddings
)
mask
=
tf
.
pad
(
mask
,
paddings
,
constant_values
=
False
)
# no attention on the padding tokens
mask
=
tf
.
pad
(
mask
,
paddings
,
token_type_ids
=
tf
.
pad
(
type_ids
,
paddings
,
constant_values
=
0
)
# pad with token_type_id = 0
constant_values
=
False
)
# no attention on the padding tokens
token_type_ids
=
tf
.
pad
(
type_ids
,
paddings
,
constant_values
=
0
)
# pad with token_type_id = 0
return
(
return
(
padding_len
,
padding_len
,
word_ids
,
word_ids
,
mask
,
mask
,
token_type_ids
,
token_type_ids
,
word_embeddings
,)
word_embeddings
,)
official/projects/longformer/longformer_encoder_block.py
View file @
09e6e71c
...
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
...
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
"""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.projects.longformer.longformer_attention
import
LongformerAttention
from
official.projects.longformer.longformer_attention
import
\
LongformerAttention
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
LongformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
class
LongformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""TransformerEncoderBlock layer.
"""LongformerEncoderBlock.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def
__init__
(
self
,
global_attention_size
,
num_attention_heads
,
inner_dim
,
inner_activation
,
# Longformer
attention_window
,
layer_id
=
0
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
"""Initializes `TransformerEncoderBlock`.
Args:
Args:
num_attention_heads: Number of attention heads.
num_attention_heads: Number of attention heads.
...
@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
attention over all axes, but batch, heads, and features.
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
**kwargs: keyword arguments/
"""
"""
def
__init__
(
self
,
global_attention_size
,
num_attention_heads
,
inner_dim
,
inner_activation
,
# Longformer
attention_window
,
layer_id
=
0
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
global_attention_size
=
global_attention_size
self
.
global_attention_size
=
global_attention_size
...
@@ -121,7 +111,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -121,7 +111,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
self
.
_inner_dropout
=
inner_dropout
self
.
_inner_dropout
=
inner_dropout
if
attention_initializer
:
if
attention_initializer
:
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
attention_initializer
)
attention_initializer
)
else
:
else
:
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_axes
=
attention_axes
self
.
_attention_axes
=
attention_axes
...
@@ -133,58 +123,58 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -133,58 +123,58 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"The type of input shape argument is not supported, got:
%s"
%
f
"The type of input shape argument is not supported, got:
"
type
(
input_shape
))
f
"
{
type
(
input_shape
)
}
"
)
einsum_equation
=
"abc,cd->abd"
einsum_equation
=
"abc,cd->abd"
if
len
(
input_tensor_shape
.
as_list
())
>
3
:
if
len
(
input_tensor_shape
.
as_list
())
>
3
:
einsum_equation
=
"...bc,cd->...bd"
einsum_equation
=
"...bc,cd->...bd"
hidden_size
=
input_tensor_shape
[
-
1
]
hidden_size
=
input_tensor_shape
[
-
1
]
if
hidden_size
%
self
.
_num_heads
!=
0
:
if
hidden_size
%
self
.
_num_heads
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The input size (
%d
) is not a multiple of the number of attention "
f
"The input size (
{
hidden_size
}
) is not a multiple of the number of attention "
"heads (
%d)"
%
(
hidden_size
,
self
.
_num_heads
)
)
f
"heads (
{
self
.
_num_heads
}
)"
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
bias_constraint
=
self
.
_bias_constraint
)
# TFLongformerSelfAttention + TFLongformerSelfOutput.dense
# TFLongformerSelfAttention + TFLongformerSelfOutput.dense
self
.
_attention_layer
=
LongformerAttention
(
self
.
_attention_layer
=
LongformerAttention
(
# Longformer
# Longformer
layer_id
=
self
.
_layer_id
,
layer_id
=
self
.
_layer_id
,
global_attention_size
=
self
.
global_attention_size
,
global_attention_size
=
self
.
global_attention_size
,
attention_window
=
self
.
_attention_window
,
attention_window
=
self
.
_attention_window
,
num_heads
=
self
.
_num_heads
,
num_heads
=
self
.
_num_heads
,
key_dim
=
self
.
_attention_head_size
,
key_dim
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout
,
dropout
=
self
.
_attention_dropout
,
use_bias
=
self
.
_use_bias
,
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_attention_initializer
,
kernel_initializer
=
self
.
_attention_initializer
,
attention_axes
=
self
.
_attention_axes
,
attention_axes
=
self
.
_attention_axes
,
name
=
"self_attention"
,
name
=
"self_attention"
,
**
common_kwargs
)
**
common_kwargs
)
# TFLongformerSelfOutput.dropout
# TFLongformerSelfOutput.dropout
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# TFLongformerSelfOutput.Layernorm
# TFLongformerSelfOutput.Layernorm
self
.
_attention_layer_norm
=
(
self
.
_attention_layer_norm
=
(
tf
.
keras
.
layers
.
LayerNormalization
(
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"self_attention_layer_norm"
,
name
=
"self_attention_layer_norm"
,
axis
=-
1
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
# TFLongformerIntermediate
# TFLongformerIntermediate
# TFLongformerIntermediate.dense
# TFLongformerIntermediate.dense
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
(
None
,
self
.
_inner_dim
),
output_shape
=
(
None
,
self
.
_inner_dim
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
name
=
"intermediate"
,
name
=
"intermediate"
,
**
common_kwargs
)
**
common_kwargs
)
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
if
policy
.
name
==
"mixed_bfloat16"
:
# bfloat16 causes BERT with the LAMB optimizer to not converge
# bfloat16 causes BERT with the LAMB optimizer to not converge
...
@@ -193,72 +183,72 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -193,72 +183,72 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
policy
=
tf
.
float32
policy
=
tf
.
float32
# TFLongformerIntermediate.intermediate_act_fn
# TFLongformerIntermediate.intermediate_act_fn
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_inner_activation
,
dtype
=
policy
)
self
.
_inner_activation
,
dtype
=
policy
)
# ???
# ???
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_inner_dropout
)
rate
=
self
.
_inner_dropout
)
# TFLongformerOutput
# TFLongformerOutput
# TFLongformerOutput.dense
# TFLongformerOutput.dense
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
einsum_equation
,
output_shape
=
(
None
,
hidden_size
),
output_shape
=
(
None
,
hidden_size
),
bias_axes
=
"d"
,
bias_axes
=
"d"
,
name
=
"output"
,
name
=
"output"
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
**
common_kwargs
)
**
common_kwargs
)
# TFLongformerOutput.dropout
# TFLongformerOutput.dropout
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
# TFLongformerOutput.layernorm
# TFLongformerOutput.layernorm
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"output_layer_norm"
,
name
=
"output_layer_norm"
,
axis
=-
1
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
super
(
LongformerEncoderBlock
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
config
=
{
"num_attention_heads"
:
"num_attention_heads"
:
self
.
_num_heads
,
self
.
_num_heads
,
"inner_dim"
:
"inner_dim"
:
self
.
_inner_dim
,
self
.
_inner_dim
,
"inner_activation"
:
"inner_activation"
:
self
.
_inner_activation
,
self
.
_inner_activation
,
"output_dropout"
:
"output_dropout"
:
self
.
_output_dropout_rate
,
self
.
_output_dropout_rate
,
"attention_dropout"
:
"attention_dropout"
:
self
.
_attention_dropout_rate
,
self
.
_attention_dropout_rate
,
"output_range"
:
"output_range"
:
self
.
_output_range
,
self
.
_output_range
,
"kernel_initializer"
:
"kernel_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_kernel_initializer
),
tf
.
keras
.
initializers
.
serialize
(
self
.
_kernel_initializer
),
"bias_initializer"
:
"bias_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_bias_initializer
),
tf
.
keras
.
initializers
.
serialize
(
self
.
_bias_initializer
),
"kernel_regularizer"
:
"kernel_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_kernel_regularizer
),
tf
.
keras
.
regularizers
.
serialize
(
self
.
_kernel_regularizer
),
"bias_regularizer"
:
"bias_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_bias_regularizer
),
tf
.
keras
.
regularizers
.
serialize
(
self
.
_bias_regularizer
),
"activity_regularizer"
:
"activity_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_activity_regularizer
),
tf
.
keras
.
regularizers
.
serialize
(
self
.
_activity_regularizer
),
"kernel_constraint"
:
"kernel_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
"bias_constraint"
:
"bias_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
"use_bias"
:
"use_bias"
:
self
.
_use_bias
,
self
.
_use_bias
,
"norm_first"
:
"norm_first"
:
self
.
_norm_first
,
self
.
_norm_first
,
"norm_epsilon"
:
"norm_epsilon"
:
self
.
_norm_epsilon
,
self
.
_norm_epsilon
,
"inner_dropout"
:
"inner_dropout"
:
self
.
_inner_dropout
,
self
.
_inner_dropout
,
"attention_initializer"
:
"attention_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
"attention_axes"
:
self
.
_attention_axes
,
"attention_axes"
:
self
.
_attention_axes
,
}
}
base_config
=
super
(
LongformerEncoderBlock
,
self
).
get_config
()
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
...
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
An output tensor with the same dimensions as input/query tensor.
An output tensor with the same dimensions as input/query tensor.
"""
"""
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
5
:
if
len
(
inputs
)
==
4
:
(
(
input_tensor
,
input_tensor
,
attention_mask
,
attention_mask
,
is_index_masked
,
is_index_masked
,
is_index_global_attn
,
is_index_global_attn
,
is_global_attn
)
=
inputs
)
=
inputs
key_value
=
None
key_value
=
None
elif
len
(
inputs
)
==
6
:
elif
len
(
inputs
)
==
5
:
assert
False
# No key_value
assert
False
# No key_value
else
:
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
raise
ValueError
(
f
"Unexpected inputs to
{
self
.
__class__
}
with length at
{
len
(
inputs
)
}
"
)
(
self
.
__class__
,
len
(
inputs
)))
else
:
else
:
input_tensor
=
inputs
input_tensor
=
inputs
attention_mask
=
None
attention_mask
=
None
is_index_masked
=
None
is_index_masked
=
None
is_index_global_attn
=
None
is_index_global_attn
=
None
is_global_attn
=
None
key_value
=
None
key_value
=
None
if
self
.
_output_range
:
if
self
.
_output_range
:
...
@@ -325,11 +312,10 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
...
@@ -325,11 +312,10 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
# attention_output = self._attention_layer(
# attention_output = self._attention_layer(
# query=target_tensor, value=key_value, attention_mask=attention_mask)
# query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output
=
self
.
_attention_layer
(
attention_output
=
self
.
_attention_layer
(
hidden_states
=
target_tensor
,
hidden_states
=
target_tensor
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
is_index_masked
=
is_index_masked
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
)
)
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
...
...
official/projects/longformer/longformer_encoder_test.py
View file @
09e6e71c
...
@@ -12,44 +12,55 @@
...
@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Tests for official.nlp.projects.
bigbird.
encoder."""
"""Tests for official.nlp.projects.
longformer.longformer_
encoder."""
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
combinations
from
official.projects.longformer.longformer_encoder
import
LongformerEncoder
from
official.projects.longformer.longformer_encoder
import
LongformerEncoder
@
keras_parameterized
.
run_all_keras_modes
@
keras_parameterized
.
run_all_keras_modes
class
LongformerEncoderTest
(
keras_parameterized
.
TestCase
):
class
LongformerEncoderTest
(
keras_parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
LongformerEncoderTest
,
self
).
setUp
()
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
@
combinations
.
generate
(
combinations
.
combine
(
@
combinations
.
generate
(
combinations
.
combine
(
attention_window
=
[
32
,
128
],
global_attention_size
=
[
0
,
1
,
2
]))
attention_window
=
[
32
,
128
],
global_attention_size
=
[
0
,
1
,
2
]))
def
test_encoder
(
self
,
attention_window
,
global_attention_size
):
def
test_encoder
(
self
,
attention_window
,
global_attention_size
):
sequence_length
=
128
sequence_length
=
128
batch_size
=
2
batch_size
=
2
vocab_size
=
1024
vocab_size
=
1024
hidden_size
=
256
hidden_size
=
256
network
=
LongformerEncoder
(
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
attention_window
=
attention_window
,
attention_window
=
[
attention_window
],
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_layers
=
1
,
num_attention_heads
=
4
,
num_attention_heads
=
4
,
max_sequence_length
=
512
)
max_sequence_length
=
512
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
size
=
(
batch_size
,
sequence_length
),
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
'input_type_ids'
:
type_id_data
,
}
}
outputs
=
network
(
inputs
)
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
"
sequence_output
"
].
shape
,
self
.
assertEqual
(
outputs
[
'
sequence_output
'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
(
batch_size
,
sequence_length
,
hidden_size
))
@
combinations
.
generate
(
combinations
.
combine
(
@
combinations
.
generate
(
combinations
.
combine
(
...
@@ -60,26 +71,30 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
...
@@ -60,26 +71,30 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
vocab_size
=
1024
vocab_size
=
1024
hidden_size
=
256
hidden_size
=
256
network
=
LongformerEncoder
(
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
attention_window
=
32
,
attention_window
=
[
32
],
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_layers
=
1
,
num_attention_heads
=
4
,
num_attention_heads
=
4
,
max_sequence_length
=
512
,
max_sequence_length
=
512
,
norm_first
=
norm_first
)
norm_first
=
norm_first
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
size
=
(
batch_size
,
sequence_length
),
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
'input_type_ids'
:
type_id_data
,
}
}
outputs
=
network
(
inputs
)
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
"
sequence_output
"
].
shape
,
self
.
assertEqual
(
outputs
[
'
sequence_output
'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
(
batch_size
,
sequence_length
,
hidden_size
))
if
__name__
==
"__main__"
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
\ No newline at end of file
official/projects/longformer/longformer_experiments.py
View file @
09e6e71c
...
@@ -34,84 +34,90 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
...
@@ -34,84 +34,90 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr
=
optimization
.
PolynomialLrConfig
PolynomialLr
=
optimization
.
PolynomialLrConfig
PolynomialWarmupConfig
=
optimization
.
PolynomialWarmupConfig
PolynomialWarmupConfig
=
optimization
.
PolynomialWarmupConfig
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
LongformerOptimizationConfig
(
optimization
.
OptimizationConfig
):
class
LongformerOptimizationConfig
(
optimization
.
OptimizationConfig
):
optimizer
:
optimization
.
OptimizerConfig
=
optimization
.
OptimizerConfig
(
optimizer
:
optimization
.
OptimizerConfig
=
optimization
.
OptimizerConfig
(
type
=
"
adamw
"
,
type
=
'
adamw
'
,
adamw
=
AdamWeightDecay
(
adamw
=
AdamWeightDecay
(
weight_decay_rate
=
0.01
,
weight_decay_rate
=
0.01
,
exclude_from_weight_decay
=
[
"
LayerNorm
"
,
"
layer_norm
"
,
"
bias
"
],
exclude_from_weight_decay
=
[
'
LayerNorm
'
,
'
layer_norm
'
,
'
bias
'
],
epsilon
=
1e-6
))
epsilon
=
1e-6
))
learning_rate
:
optimization
.
LrConfig
=
optimization
.
LrConfig
(
learning_rate
:
optimization
.
LrConfig
=
optimization
.
LrConfig
(
type
=
"
polynomial
"
,
type
=
'
polynomial
'
,
polynomial
=
PolynomialLr
(
polynomial
=
PolynomialLr
(
initial_learning_rate
=
1e-4
,
initial_learning_rate
=
1e-4
,
decay_steps
=
1000000
,
decay_steps
=
1000000
,
end_learning_rate
=
0.0
))
end_learning_rate
=
0.0
))
warmup
:
optimization
.
WarmupConfig
=
optimization
.
WarmupConfig
(
warmup
:
optimization
.
WarmupConfig
=
optimization
.
WarmupConfig
(
type
=
"polynomial"
,
polynomial
=
PolynomialWarmupConfig
(
warmup_steps
=
10000
))
type
=
'polynomial'
,
polynomial
=
PolynomialWarmupConfig
(
warmup_steps
=
10000
))
@
exp_factory
.
register_config_factory
(
'longformer/pretraining'
)
@
exp_factory
.
register_config_factory
(
'longformer/pretraining'
)
def
longformer_pretraining
()
->
cfg
.
ExperimentConfig
:
def
longformer_pretraining
()
->
cfg
.
ExperimentConfig
:
"""BERT pretraining experiment."""
"""BERT pretraining experiment."""
config
=
cfg
.
ExperimentConfig
(
config
=
cfg
.
ExperimentConfig
(
runtime
=
cfg
.
RuntimeConfig
(
enable_xla
=
True
),
runtime
=
cfg
.
RuntimeConfig
(
enable_xla
=
True
),
task
=
masked_lm
.
MaskedLMConfig
(
task
=
masked_lm
.
MaskedLMConfig
(
model
=
bert
.
PretrainerConfig
(
model
=
bert
.
PretrainerConfig
(
encoder
=
encoders
.
EncoderConfig
(
encoder
=
encoders
.
EncoderConfig
(
type
=
"any"
,
any
=
LongformerEncoderConfig
()),
type
=
"any"
,
any
=
LongformerEncoderConfig
()),
cls_heads
=
[
cls_heads
=
[
bert
.
ClsHeadConfig
(
bert
.
ClsHeadConfig
(
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
]
name
=
'next_sentence'
)
),
]
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
),
),
validation_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
,
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
is_training
=
False
)),
use_v2_feature_names
=
True
),
trainer
=
cfg
.
TrainerConfig
(
validation_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
optimizer_config
=
LongformerOptimizationConfig
(),
train_steps
=
1000000
),
use_v2_feature_names
=
True
,
restrictions
=
[
is_training
=
False
)),
'task.train_data.is_training != None'
,
trainer
=
cfg
.
TrainerConfig
(
'task.validation_data.is_training != None'
optimizer_config
=
LongformerOptimizationConfig
(),
train_steps
=
1000000
),
])
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
return
config
@
exp_factory
.
register_config_factory
(
'longformer/glue'
)
@
exp_factory
.
register_config_factory
(
'longformer/glue'
)
def
longformer_glue
()
->
cfg
.
ExperimentConfig
:
def
longformer_glue
()
->
cfg
.
ExperimentConfig
:
config
=
cfg
.
ExperimentConfig
(
config
=
cfg
.
ExperimentConfig
(
task
=
sentence_prediction
.
SentencePredictionConfig
(
task
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
sentence_prediction
.
ModelConfig
(
model
=
sentence_prediction
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
encoder
=
encoders
.
EncoderConfig
(
type
=
"any"
,
any
=
LongformerEncoderConfig
())),
type
=
"any"
,
any
=
LongformerEncoderConfig
())),
train_data
=
sentence_prediction_dataloader
train_data
=
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(),
.
SentencePredictionDataConfig
(),
validation_data
=
sentence_prediction_dataloader
validation_data
=
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(
.
SentencePredictionDataConfig
(
is_training
=
False
,
drop_remainder
=
False
)),
is_training
=
False
,
drop_remainder
=
False
)),
trainer
=
cfg
.
TrainerConfig
(
trainer
=
cfg
.
TrainerConfig
(
optimizer_config
=
optimization
.
OptimizationConfig
({
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'adamw'
,
'type'
:
'adamw'
,
'adamw'
:
{
'adamw'
:
{
'weight_decay_rate'
:
'weight_decay_rate'
:
0.01
,
0.01
,
'exclude_from_weight_decay'
:
'exclude_from_weight_decay'
:
[
'LayerNorm'
,
'layer_norm'
,
'bias'
],
[
'LayerNorm'
,
'layer_norm'
,
'bias'
],
}
}
},
},
'learning_rate'
:
{
'learning_rate'
:
{
'type'
:
'polynomial'
,
'type'
:
'polynomial'
,
'polynomial'
:
{
'polynomial'
:
{
'initial_learning_rate'
:
3e-5
,
'initial_learning_rate'
:
3e-5
,
'end_learning_rate'
:
0.0
,
'end_learning_rate'
:
0.0
,
}
}
},
},
'warmup'
:
{
'warmup'
:
{
'type'
:
'polynomial'
'type'
:
'polynomial'
}
}
})),
})),
restrictions
=
[
restrictions
=
[
'task.train_data.is_training != None'
,
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
'task.validation_data.is_training != None'
])
])
return
config
return
config
official/projects/longformer/train.py
View file @
09e6e71c
...
@@ -24,7 +24,6 @@ from official.core import task_factory
...
@@ -24,7 +24,6 @@ from official.core import task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.projects.longformer
import
longformer_experiments
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -43,23 +42,24 @@ def main(_):
...
@@ -43,23 +42,24 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
,
tpu_address
=
params
.
runtime
.
tpu
,
**
params
.
runtime
.
model_parallelism
())
**
params
.
runtime
.
model_parallelism
())
with
distribution_strategy
.
scope
():
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
train_lib
.
run_experiment
(
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
distribution_strategy
=
distribution_strategy
,
task
=
task
,
task
=
task
,
mode
=
FLAGS
.
mode
,
mode
=
FLAGS
.
mode
,
params
=
params
,
params
=
params
,
model_dir
=
model_dir
)
model_dir
=
model_dir
)
train_utils
.
save_gin_config
(
FLAGS
.
mode
,
model_dir
)
train_utils
.
save_gin_config
(
FLAGS
.
mode
,
model_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment