Commit 09e6e71c authored by Zihan Wang

lint

parent 32867f40
official/projects/longformer/longformer_attention_test.py
@@ -12,25 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.projects.longformer import longformer_attention
from official.modeling.tf_utils import get_shape_list


def _create_mock_attention_data(
    num_heads,
    key_dim,
    value_dim,
    q_seq_length,
    kv_seq_length,
    batch_size,
    include_mask=False):
  """Creates mock testing data.

  Args:
@@ -48,15 +49,15 @@ def _create_mock_attention_data(
  value_shape = (batch_size, kv_seq_length, value_dim)
  data = dict(
      query=tf.random.normal(shape=query_shape),
      value=tf.random.normal(shape=value_shape),
      key=tf.random.normal(shape=value_shape))

  total_seq_length = kv_seq_length
  if include_mask:
    mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype('float32')
    mask_data = dict(attention_mask=mask_data)
    data.update(mask_data)
@@ -65,6 +66,12 @@
@keras_parameterized.run_all_keras_modes
class LongformerAttentionTest(keras_parameterized.TestCase):

  def setUp(self):
    super(LongformerAttentionTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  def _get_hidden_states(self):
    return tf.convert_to_tensor(
        [
@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
  def test_diagonalize(self):
    hidden_states = self._get_hidden_states()
    hidden_states = tf.reshape(hidden_states,
                               (1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    window_overlap_size = get_shape_list(chunked_hidden_states)[2]
    self.assertTrue(window_overlap_size == 4)

    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
        chunked_hidden_states)
    self.assertTrue(
        get_shape_list(padded_hidden_states)[-1] ==
        get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
    )

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4],
                             chunked_hidden_states[0, 0, 0], rtol=1e-3)
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:],
                             tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:],
                             chunked_hidden_states[0, 0, -1], rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, :3],
        tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
    )

  def test_pad_and_transpose_last_two_dims(self):
@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
                           dtype=tf.dtypes.int32)

    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
        hidden_states, paddings)
    self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])

    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(expected_added_dim,
                             padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :],
        tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
    )

  def test_mask_invalid_locations(self):
@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 1)
    hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 2)
    hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, :, :3], 2)
    hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, 2:, :], 2)

    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)

  def test_chunk(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))

    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = tf.convert_to_tensor(
        [0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
    expected_slice_along_chunk = tf.convert_to_tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

    self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
    tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0],
                             expected_slice_along_seq_length, rtol=1e-3)
    tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0],
                             expected_slice_along_chunk, rtol=1e-3)

  def test_layer_local_attn(self):
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.shape
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
    is_index_global_attn = tf.math.greater(attention_mask, 1)

    attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0,
                              attention_mask[:, :, None, None])
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (1, 4, 8))
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    )

    hidden_states = self._get_hidden_states()
    hidden_states = tf.concat(
        [self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
    batch_size, seq_length, hidden_size = hidden_states.shape

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_1)
    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0,
                                attention_mask_1)
    attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_2)

    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=-tf.math.abs(attention_mask),
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (2, 4, 8))


if __name__ == '__main__':
  tf.test.main()
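The shape assertions in these tests encode the sliding-window bookkeeping that `_chunk` and `_pad_and_diagonalize` are expected to satisfy. As a reading aid, here is a small standalone sketch of that arithmetic; the stride interpretation is an assumption inferred from the asserted shapes, not taken from the layer implementation.

# Assumed chunking arithmetic behind test_chunk and test_diagonalize:
# an 8-token sequence with window_overlap=2 is split into overlapping
# chunks of length 2 * window_overlap with stride window_overlap.
seq_length, hidden_size, window_overlap = 8, 4, 2
num_chunks = seq_length // window_overlap - 1       # 8 // 2 - 1 = 3
chunk_length = 2 * window_overlap                   # 4
print((1, num_chunks, chunk_length, hidden_size))   # (1, 3, 4, 4), as asserted

# _pad_and_diagonalize is asserted to widen the last axis by
# window_overlap_size - 1 zeros (window_overlap_size == chunk_length here):
print(hidden_size + chunk_length - 1)               # 4 + 4 - 1 = 7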
official/projects/longformer/longformer_encoder.py
@@ -23,29 +23,16 @@ from absl import logging
import tensorflow as tf

from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import \
    LongformerEncoderBlock
from official.modeling.tf_utils import get_shape_list

_Initializer = Union[str, tf.keras.initializers.Initializer]

_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)


# Transferred from huggingface.longformer.TFLongformerMainLayer & TFLongformerEncoder
class LongformerEncoder(tf.keras.layers.Layer):
  """LongformerEncoder.

  Args:
    vocab_size: The size of the token vocabulary.
    attention_window: list of ints representing the window size for each layer.
@@ -85,27 +72,27 @@ class LongformerEncoder(tf.keras.layers.Layer):
  """

  def __init__(
      self,
      vocab_size: int,
      attention_window: Union[List[int], int] = 512,
      global_attention_size: int = 0,
      pad_token_id: int = 1,
      hidden_size: int = 768,
      num_layers: int = 12,
      num_attention_heads: int = 12,
      max_sequence_length: int = 512,
      type_vocab_size: int = 16,
      inner_dim: int = 3072,
      inner_activation: Callable[..., Any] = _approx_gelu,
      output_dropout: float = 0.1,
      attention_dropout: float = 0.1,
      initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      output_range: Optional[int] = None,
      embedding_width: Optional[int] = None,
      embedding_layer: Optional[tf.keras.layers.Layer] = None,
      norm_first: bool = False,
      **kwargs):
    super().__init__(**kwargs)
    # Longformer args
    self._attention_window = attention_window
@@ -120,93 +107,91 @@ class LongformerEncoder(tf.keras.layers.Layer):
    if embedding_layer is None:
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
          initializer=initializer,
          name='word_embeddings')
    else:
      self._embedding_layer = embedding_layer

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)

    self._embedding_dropout = tf.keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    self._embedding_projection = None
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')

    self._transformer_layers = []
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    for i in range(num_layers):
      layer = LongformerEncoderBlock(
          global_attention_size=global_attention_size,
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          attention_window=attention_window[i],
          layer_id=i,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=initializer,
          name=f'transformer/layer_{i}')
      self._transformer_layers.append(layer)

    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'attention_window': attention_window,
        'global_attention_size': global_attention_size,
        'pad_token_id': pad_token_id,
    }
    self.inputs = dict(
        input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
        input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
        input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))

  def call(self, inputs):
    word_embeddings = None
@@ -214,22 +199,23 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_ids = inputs.get('input_word_ids')  # input_ids
      mask = inputs.get('input_mask')  # attention_mask
      type_ids = inputs.get('input_type_ids')  # token_type_ids
      word_embeddings = inputs.get('input_word_embeddings',
                                   None)  # input_embeds
    else:
      raise ValueError(f'Unexpected inputs type to {self.__class__}.')

    (
        padding_len,
        word_ids,
        mask,
        type_ids,
        word_embeddings,
    ) = self._pad_to_window_size(
        word_ids=word_ids,
        mask=mask,
        type_ids=type_ids,
        word_embeddings=word_embeddings,
        pad_token_id=self._pad_token_id
    )

    if word_embeddings is None:
@@ -247,46 +233,47 @@ class LongformerEncoder(tf.keras.layers.Layer):
    batch_size, seq_len = get_shape_list(mask)
    # create masks with fixed len global_attention_size
    mask = tf.transpose(tf.concat(
        values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
                tf.transpose(mask)[self._global_attention_size:]], axis=0))
    is_index_masked = tf.math.less(mask, 1)
    is_index_global_attn = tf.transpose(tf.concat(values=[
        tf.ones((self._global_attention_size, batch_size), tf.bool),
        tf.zeros((seq_len - self._global_attention_size,
                  batch_size), tf.bool)
    ], axis=0))

    # Longformer
    attention_mask = mask
    extended_attention_mask = tf.reshape(
        attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1)
    )
    attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask),
                             tf.dtypes.float32) * -10000.0

    encoder_outputs = []
    x = embeddings
    # TFLongformerEncoder
    for layer in self._transformer_layers:
      x = layer([
          x,
          attention_mask,
          is_index_masked,
          is_index_global_attn])
      encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
    if padding_len > 0:
      last_encoder_output = last_encoder_output[:, :-padding_len]

    first_token_tensor = last_encoder_output[:, 0, :]
    pooled_output = self._pooler_layer(first_token_tensor)

    return dict(
        sequence_output=last_encoder_output,
        pooled_output=pooled_output,
        encoder_outputs=encoder_outputs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings
@@ -311,36 +298,36 @@ class LongformerEncoder(tf.keras.layers.Layer):
  def from_config(cls, config, custom_objects=None):
    if 'embedding_layer' in config and config['embedding_layer'] is not None:
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you contine to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warn(warn_string)

    return cls(**config)

  def _pad_to_window_size(
      self,
      word_ids,
      mask,
      type_ids,
      word_embeddings,
      pad_token_id,
  ):
    # padding
    attention_window = max(self._attention_window)

    assert attention_window % 2 == 0, \
        f'`attention_window` should be an even value. Given {attention_window}'

    input_shape = get_shape_list(
        word_ids) if word_ids is not None else get_shape_list(word_embeddings)
    batch_size, seq_len = input_shape[:2]
    if seq_len is not None:
      padding_len = (attention_window -
                     seq_len % attention_window) % attention_window
    else:
      padding_len = 0
@@ -355,14 +342,17 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_embeddings_padding = self._embedding_layer(word_ids_padding)
      return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

    word_embeddings = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings,
                              lambda: word_embeddings)

    mask = tf.pad(mask, paddings,
                  constant_values=False)  # no attention on the padding tokens
    token_type_ids = tf.pad(type_ids, paddings,
                            constant_values=0)  # pad with token_type_id = 0

    return (
        padding_len,
        word_ids,
        mask,
        token_type_ids,
        word_embeddings,)
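The padding expression in `_pad_to_window_size` rounds every sequence length up to the next multiple of the largest attention window. A short worked example (the window value of 128 is illustrative only):

attention_window = 128  # max(self._attention_window); must be even
for seq_len in (100, 128, 300):
  padding_len = (attention_window - seq_len % attention_window) % attention_window
  print(seq_len, padding_len, seq_len + padding_len)
# 100 -> pads 28 tokens to reach 128
# 128 -> pads 0 tokens (already a multiple)
# 300 -> pads 84 tokens to reach 384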
official/projects/longformer/longformer_encoder_block.py
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
import tensorflow as tf

from official.projects.longformer.longformer_attention import \
    LongformerAttention


@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerEncoderBlock(tf.keras.layers.Layer):
  """LongformerEncoderBlock.

  Args:
    num_attention_heads: Number of attention heads.
@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      attention over all axes, but batch, heads, and features.
    **kwargs: keyword arguments/
  """

  def __init__(self,
               global_attention_size,
               num_attention_heads,
               inner_dim,
               inner_activation,
               # Longformer
               attention_window,
               layer_id=0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    super().__init__(**kwargs)
    self.global_attention_size = global_attention_size
@@ -121,7 +111,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
    self._inner_dropout = inner_dropout
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = self._kernel_initializer
    self._attention_axes = attention_axes
@@ -133,58 +123,58 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          f"The type of input shape argument is not supported, got: "
          f"{type(input_shape)}")
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          f"The input size ({hidden_size}) is not a multiple of the number of attention "
          f"heads ({self._num_heads})")
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # TFLongformerSelfAttention + TFLongformerSelfOutput.dense
    self._attention_layer = LongformerAttention(
        # Longformer
        layer_id=self._layer_id,
        global_attention_size=self.global_attention_size,
        attention_window=self._attention_window,
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    # TFLongformerSelfOutput.dropout
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    # TFLongformerSelfOutput.Layernorm
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    # TFLongformerIntermediate
    # TFLongformerIntermediate.dense
    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -193,72 +183,72 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      policy = tf.float32
    # TFLongformerIntermediate.intermediate_act_fn
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._inner_activation, dtype=policy)
    # ???
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    # TFLongformerOutput
    # TFLongformerOutput.dense
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=self._kernel_initializer,
        **common_kwargs)
    # TFLongformerOutput.dropout
    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # TFLongformerOutput.layernorm
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    super().build(input_shape)
  def get_config(self):
    config = {
        "num_attention_heads":
            self._num_heads,
        "inner_dim":
            self._inner_dim,
        "inner_activation":
            self._inner_activation,
        "output_dropout":
            self._output_dropout_rate,
        "attention_dropout":
            self._attention_dropout_rate,
        "output_range":
            self._output_range,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias":
            self._use_bias,
        "norm_first":
            self._norm_first,
        "norm_epsilon":
            self._norm_epsilon,
        "inner_dropout":
            self._inner_dropout,
        "attention_initializer":
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 4:
        (
            input_tensor,
            attention_mask,
            is_index_masked,
            is_index_global_attn,
        ) = inputs
        key_value = None
      elif len(inputs) == 5:
        assert False  # No key_value
      else:
        raise ValueError(f"Unexpected inputs to {self.__class__} with length at {len(inputs)}")
    else:
      input_tensor = inputs
      attention_mask = None
      is_index_masked = None
      is_index_global_attn = None
      key_value = None

    if self._output_range:
@@ -325,11 +312,10 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
    # attention_output = self._attention_layer(
    #     query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_layer(
        hidden_states=target_tensor,
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )
    # TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
    attention_output = self._attention_dropout(attention_output)
......
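After this change `LongformerEncoderBlock.call` expects a 4-element list, `[x, attention_mask, is_index_masked, is_index_global_attn]`, with the separate `is_global_attn` flag gone. The sketch below rebuilds those mask tensors from a plain 0/1 input mask, mirroring the encoder `call` hunk shown earlier; the toy batch and sequence sizes are illustrative assumptions.

import tensorflow as tf

global_attention_size = 1
input_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]], tf.int32)  # (batch=2, seq=4)
batch_size, seq_len = input_mask.shape

# Mark the first `global_attention_size` positions with the value 2.
mask = tf.transpose(tf.concat(
    [tf.ones((global_attention_size, batch_size), tf.int32) * 2,
     tf.transpose(input_mask)[global_attention_size:]], axis=0))

is_index_masked = tf.math.less(mask, 1)  # True at padding positions
is_index_global_attn = tf.transpose(tf.concat(
    [tf.ones((global_attention_size, batch_size), tf.bool),
     tf.zeros((seq_len - global_attention_size, batch_size), tf.bool)], axis=0))

# Additive mask as built in the encoder: positions with mask == 1 map to 0,
# everything else maps to -10000.
extended = tf.reshape(mask, (batch_size, seq_len, 1, 1))
attention_mask = tf.cast(tf.math.abs(1 - extended), tf.float32) * -10000.0

# These are the four entries each block now receives:
# block_inputs = [hidden_states, attention_mask, is_index_masked, is_index_global_attn]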
official/projects/longformer/longformer_encoder_test.py
@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
import numpy as np
import tensorflow as tf

from absl.testing import parameterized
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations
from official.projects.longformer.longformer_encoder import LongformerEncoder


@keras_parameterized.run_all_keras_modes
class LongformerEncoderTest(keras_parameterized.TestCase):

  def setUp(self):
    super(LongformerEncoderTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  @combinations.generate(combinations.combine(
      attention_window=[32, 128], global_attention_size=[0, 1, 2]))
  def test_encoder(self, attention_window, global_attention_size):
    sequence_length = 128
    batch_size = 2
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[attention_window],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))

  @combinations.generate(combinations.combine(
@@ -60,26 +71,30 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[32],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512,
        norm_first=norm_first)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))


if __name__ == '__main__':
  tf.test.main()
\ No newline at end of file
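These tests now pass `attention_window` as a one-element list, matching the encoder change that indexes `attention_window[i]` per layer. A hedged sketch of a multi-layer configuration under the same convention (the layer count and window values here are illustrative, not from the source):

import numpy as np
import tensorflow as tf
from official.projects.longformer.longformer_encoder import LongformerEncoder

encoder = LongformerEncoder(
    vocab_size=1024,
    attention_window=[32, 64],   # one even window size per layer
    global_attention_size=1,
    hidden_size=256,
    num_layers=2,
    num_attention_heads=4,
    max_sequence_length=512)

inputs = {
    'input_word_ids': np.random.randint(1024, size=(2, 128), dtype=np.int32),
    'input_mask': np.ones((2, 128), dtype=np.int32),
    'input_type_ids': np.zeros((2, 128), dtype=np.int32),
}
outputs = encoder(inputs)
print(outputs['sequence_output'].shape)  # (2, 128, 256)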
official/projects/longformer/longformer_experiments.py
@@ -34,84 +34,90 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type='adamw',
      adamw=AdamWeightDecay(
          weight_decay_rate=0.01,
          exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
          epsilon=1e-6))
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type='polynomial',
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))
@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
  """BERT pretraining experiment."""
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=masked_lm.MaskedLMConfig(
          model=bert.PretrainerConfig(
              encoder=encoders.EncoderConfig(
                  type="any", any=LongformerEncoderConfig()),
              cls_heads=[
                  bert.ClsHeadConfig(
                      inner_dim=768, num_classes=2, dropout_rate=0.1,
                      name='next_sentence')
              ]
          ),
          train_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True,
              is_training=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          model=sentence_prediction.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type="any", any=LongformerEncoderConfig())),
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate':
                          0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 3e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
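For context, factories registered this way are normally looked up by name through the Model Garden experiment factory. A minimal sketch, assuming the standard `official.core.exp_factory` API; the overrides and path are hypothetical:

from official.core import exp_factory
# Importing this module runs the @exp_factory.register_config_factory
# decorators above, so the experiment names become resolvable.
from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('longformer/pretraining')
config.trainer.train_steps = 1000                            # hypothetical override
config.task.train_data.input_path = '/tmp/train.tfrecord'    # hypothetical path
print(config.trainer.optimizer_config.optimizer.type)        # adamw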
official/projects/longformer/train.py
@@ -24,7 +24,6 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance

FLAGS = flags.FLAGS

@@ -43,23 +42,24 @@ def main(_):
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
......