Commit 09e6e71c authored by Zihan Wang

lint

parent 32867f40
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""

import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.projects.longformer import longformer_attention
from official.modeling.tf_utils import get_shape_list
@@ -56,7 +57,7 @@ def _create_mock_attention_data(
  if include_mask:
    mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype('float32')
    mask_data = dict(attention_mask=mask_data)
    data.update(mask_data)
@@ -65,6 +66,12 @@
@keras_parameterized.run_all_keras_modes
class LongformerAttentionTest(keras_parameterized.TestCase):

  def setUp(self):
    super(LongformerAttentionTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  def _get_hidden_states(self):
    return tf.convert_to_tensor(
        [
@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
  def test_diagonalize(self):
    hidden_states = self._get_hidden_states()
    hidden_states = tf.reshape(hidden_states,
                               (1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    window_overlap_size = get_shape_list(chunked_hidden_states)[2]
    self.assertTrue(window_overlap_size == 4)

    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
        chunked_hidden_states)

    self.assertTrue(
        get_shape_list(padded_hidden_states)[-1] ==
        get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
    )

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4],
                             chunked_hidden_states[0, 0, 0], rtol=1e-3)
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:],
                             tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:],
                             chunked_hidden_states[0, 0, -1], rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, :3],
        tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
    )
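The pair of `assert_near` checks above pins down what `_pad_and_diagonalize` is expected to do per chunk: row `i` of a `(rows, width)` slice ends up shifted right by `i` columns inside a `(rows, width + rows - 1)` buffer, so the first row trails zeros and the last row leads with them. A minimal NumPy sketch of that staggering (an editorial illustration of the asserted behavior, not part of this commit and not the vectorized implementation in `longformer_attention`):

```python
import numpy as np


def stagger_rows(chunk):
  """Shift row i right by i columns and zero-fill the rest (illustration only)."""
  rows, width = chunk.shape
  out = np.zeros((rows, width + rows - 1), dtype=chunk.dtype)
  for i in range(rows):
    out[i, i:i + width] = chunk[i]
  return out


chunk = np.arange(16, dtype=np.float32).reshape(4, 4)
staggered = stagger_rows(chunk)
assert staggered.shape == (4, 7)       # width 4 + 4 rows - 1, as asserted above
assert (staggered[0, 4:] == 0).all()   # first row: values, then zeros
assert (staggered[-1, :3] == 0).all()  # last row: zeros, then values
```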
  def test_pad_and_transpose_last_two_dims(self):

@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
                           dtype=tf.dtypes.int32)

    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
        hidden_states, paddings)
    self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])

    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(expected_added_dim,
                             padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :],
        tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
    )

  def test_mask_invalid_locations(self):
@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 1)
    hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 2)
    hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, :, :3], 2)
    hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, 2:, :], 2)

    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
    self.assertTrue(tf.math.reduce_sum(
        tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
  def test_chunk(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))

    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = tf.convert_to_tensor(
        [0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
    expected_slice_along_chunk = tf.convert_to_tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

    self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
    tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0],
                             expected_slice_along_seq_length, rtol=1e-3)
    tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0],
                             expected_slice_along_chunk, rtol=1e-3)
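For reference, `_chunk` with `window_overlap=2` splits the length-8 sequence into overlapping windows of length `2 * window_overlap` advancing by `window_overlap`, which is why the `(1, 8, 4)` input becomes `[1, 3, 4, 4]` above. A hedged sketch of that shape behavior using `tf.signal.frame` (an illustrative stand-in added here, not the actual `_chunk` implementation):

```python
import tensorflow as tf


def chunk_like(hidden_states, window_overlap):
  # Overlapping windows of length 2 * window_overlap, stepping by window_overlap.
  return tf.signal.frame(hidden_states,
                         frame_length=2 * window_overlap,
                         frame_step=window_overlap,
                         axis=1)


x = tf.reshape(tf.range(32, dtype=tf.float32), (1, 8, 4))
print(chunk_like(x, window_overlap=2).shape)  # (1, 3, 4, 4)
```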
  def test_layer_local_attn(self):
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.shape
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,

@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
    is_index_global_attn = tf.math.greater(attention_mask, 1)

    attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0,
                              attention_mask[:, :, None, None])
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (1, 4, 8))
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    )

    hidden_states = self._get_hidden_states()
    hidden_states = tf.concat(
        [self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
    batch_size, seq_length, hidden_size = hidden_states.shape

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_1)
    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0,
                                attention_mask_1)
    attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_2)

    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=-tf.math.abs(attention_mask),
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (2, 4, 8))


if __name__ == '__main__':
  tf.test.main()
@@ -23,29 +23,16 @@ from absl import logging
import tensorflow as tf

from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import \
    LongformerEncoderBlock
from official.modeling.tf_utils import get_shape_list

_Initializer = Union[str, tf.keras.initializers.Initializer]

_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)


class LongformerEncoder(tf.keras.layers.Layer):
  """LongformerEncoder

  Args:
    vocab_size: The size of the token vocabulary.
    attention_window: list of ints representing the window size for each layer.
@@ -165,15 +152,14 @@ class LongformerEncoder(tf.keras.layers.Layer):
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          attention_window=attention_window[i],
          layer_id=i,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=initializer,
          name=f'transformer/layer_{i}')
      self._transformer_layers.append(layer)

    self._pooler_layer = tf.keras.layers.Dense(

@@ -198,7 +184,6 @@ class LongformerEncoder(tf.keras.layers.Layer):
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'attention_window': attention_window,
        'global_attention_size': global_attention_size,
        'pad_token_id': pad_token_id,
@@ -214,9 +199,10 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_ids = inputs.get('input_word_ids')  # input_ids
      mask = inputs.get('input_mask')  # attention_mask
      type_ids = inputs.get('input_type_ids')  # token_type_ids
      word_embeddings = inputs.get('input_word_embeddings',
                                   None)  # input_embeds
    else:
      raise ValueError(f'Unexpected inputs type to {self.__class__}.')

    (
        padding_len,
@@ -247,34 +233,35 @@ class LongformerEncoder(tf.keras.layers.Layer):
    batch_size, seq_len = get_shape_list(mask)
    # create masks with fixed len global_attention_size
    mask = tf.transpose(tf.concat(
        values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
                tf.transpose(mask)[self._global_attention_size:]], axis=0))

    is_index_masked = tf.math.less(mask, 1)
    is_index_global_attn = tf.transpose(tf.concat(values=[
        tf.ones((self._global_attention_size, batch_size), tf.bool),
        tf.zeros((seq_len - self._global_attention_size,
                  batch_size), tf.bool)
    ], axis=0))

    # Longformer
    attention_mask = mask
    extended_attention_mask = tf.reshape(
        attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1)
    )
    attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask),
                             tf.dtypes.float32) * -10000.0
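To make the mask construction above concrete: the first `global_attention_size` positions of every example are overwritten with the value 2, and the additive bias is `-10000 * |1 - m|`, so regular tokens (m = 1) contribute 0, padding (m = 0) contributes -10000, and the global positions (m = 2) also map to -10000 here while being flagged separately through `is_index_global_attn`. A small worked example with assumed values (editorial illustration, not part of the change):

```python
import tensorflow as tf

g = 2  # assumed global_attention_size
mask = tf.constant([[1, 1, 1, 1, 0, 0]], tf.int32)  # one example, last two padded
mask = tf.transpose(tf.concat(
    values=[tf.ones((g, 1), tf.int32) * 2, tf.transpose(mask)[g:]], axis=0))
# mask is now [[2, 2, 1, 1, 0, 0]]: the first g tokens are marked global.

extended = tf.reshape(mask, (1, 6, 1, 1))
additive = tf.cast(tf.math.abs(1 - extended), tf.float32) * -10000.0
print(tf.reshape(additive, (1, 6)).numpy())
# [[-10000. -10000.      0.      0. -10000. -10000.]]
```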
    encoder_outputs = []
    x = embeddings
    # TFLongformerEncoder
    for layer in self._transformer_layers:
      x = layer([
          x,
          attention_mask,
          is_index_masked,
          is_index_global_attn])
      encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
@@ -328,19 +315,19 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_embeddings,
      pad_token_id,
  ):
    # padding
    attention_window = max(self._attention_window)

    assert attention_window % 2 == 0, \
        f'`attention_window` should be an even value. Given {attention_window}'

    input_shape = get_shape_list(
        word_ids) if word_ids is not None else get_shape_list(word_embeddings)
    batch_size, seq_len = input_shape[:2]

    if seq_len is not None:
      padding_len = (attention_window -
                     seq_len % attention_window) % attention_window
    else:
      padding_len = 0
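The modulo expression above rounds `seq_len` up to the next multiple of the (largest) attention window. A quick check with assumed values:

```python
# Assumed example values, mirroring the formula above (illustration only).
attention_window = 32
for seq_len in (128, 100, 1):
  padding_len = (attention_window - seq_len % attention_window) % attention_window
  print(seq_len, '->', padding_len)  # 128 -> 0, 100 -> 28, 1 -> 31
```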
@@ -355,10 +342,13 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_embeddings_padding = self._embedding_layer(word_ids_padding)
      return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

    word_embeddings = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings,
                              lambda: word_embeddings)

    mask = tf.pad(mask, paddings,
                  constant_values=False)  # no attention on the padding tokens
    token_type_ids = tf.pad(type_ids, paddings,
                            constant_values=0)  # pad with token_type_id = 0

    return (
        padding_len,
...
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
import tensorflow as tf

from official.projects.longformer.longformer_attention import \
    LongformerAttention


@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerEncoderBlock(tf.keras.layers.Layer):
  """LongformerEncoderBlock.

  Args:
    num_attention_heads: Number of attention heads.
@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      attention over all axes, but batch, heads, and features.
    **kwargs: keyword arguments/
  """

  def __init__(self,
               global_attention_size,
               num_attention_heads,
               inner_dim,
               inner_activation,
               # Longformer
               attention_window,
               layer_id=0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    super().__init__(**kwargs)
    self.global_attention_size = global_attention_size
@@ -133,16 +123,16 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          f"The type of input shape argument is not supported, got: "
          f"{type(input_shape)}")
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          f"The input size ({hidden_size}) is not a multiple of the number of attention "
          f"heads ({self._num_heads})")
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
@@ -216,7 +206,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    super().build(input_shape)

  def get_config(self):
    config = {
@@ -258,7 +248,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 4:
        (
            input_tensor,
            attention_mask,
            is_index_masked,
            is_index_global_attn,
        ) = inputs
        key_value = None
      elif len(inputs) == 5:
        assert False  # No key_value
      else:
        raise ValueError(f"Unexpected inputs to {self.__class__} with length at {len(inputs)}")
    else:
      input_tensor = inputs
      attention_mask = None
      is_index_masked = None
      is_index_global_attn = None
      key_value = None

    if self._output_range:
@@ -329,7 +316,6 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )
    # TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
    attention_output = self._attention_dropout(attention_output)
...
@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""

import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations

from official.projects.longformer.longformer_encoder import LongformerEncoder


@keras_parameterized.run_all_keras_modes
class LongformerEncoderTest(keras_parameterized.TestCase):

  def setUp(self):
    super(LongformerEncoderTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  @combinations.generate(combinations.combine(
      attention_window=[32, 128], global_attention_size=[0, 1, 2]))
  def test_encoder(self, attention_window, global_attention_size):
    sequence_length = 128
    batch_size = 2
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[attention_window],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))
  @combinations.generate(combinations.combine(

@@ -62,24 +73,28 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[32],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512,
        norm_first=norm_first)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))


if __name__ == '__main__':
  tf.test.main()
@@ -34,22 +34,24 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type='adamw',
      adamw=AdamWeightDecay(
          weight_decay_rate=0.01,
          exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
          epsilon=1e-6))
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type='polynomial',
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))


@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
@@ -62,11 +64,14 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
                  type="any", any=LongformerEncoderConfig()),
              cls_heads=[
                  bert.ClsHeadConfig(
                      inner_dim=768, num_classes=2, dropout_rate=0.1,
                      name='next_sentence')
              ]
          ),
          train_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True,
              is_training=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
@@ -76,6 +81,7 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
  ])
  return config


@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
  config = cfg.ExperimentConfig(
...
@@ -24,7 +24,6 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance

FLAGS = flags.FLAGS
@@ -43,7 +42,8 @@ def main(_):
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
...