Commit 350f4854 authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398286699
parent 9d5a1a76
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
"""Funnel Transformer network.""" """Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from typing import Union, Collection from typing import Union, Sequence
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp import keras_nlp from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, stride: int, def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
axes: Union[Collection[int], int]): int],
axes: Union[Sequence[int], int]):
"""Pools the data along a given axis with stride. """Pools the data along a given axis with stride.
It also skips first unpool_length elements. It also skips first unpool_length elements.
...@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int, ...@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
Args: Args:
data: Tensor to be pooled. data: Tensor to be pooled.
unpool_length: Leading elements to be skipped. unpool_length: Leading elements to be skipped.
stride: Stride for the given axis. strides: Strides for the given axes.
axes: Axes to pool the Tensor. axes: Axes to pool the Tensor.
Returns: Returns:
...@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int, ...@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
# Wraps the axes as a list. # Wraps the axes as a list.
if isinstance(axes, int): if isinstance(axes, int):
axes = [axes] axes = [axes]
if isinstance(strides, int):
strides = [strides] * len(axes)
else:
if len(strides) != len(axes):
raise ValueError('The lengths of strides and axes need to match.')
for axis in axes: for axis, stride in zip(axes, strides):
# Skips first `unpool_length` tokens. # Skips first `unpool_length` tokens.
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)] unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
unpool_tensor = data[unpool_tensor_shape] unpool_tensor = data[unpool_tensor_shape]
...@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout. dropout.
attention_dropout: The dropout rate to use for the attention layers within attention_dropout: The dropout rate to use for the attention layers within
the transformer layers. the transformer layers.
pool_stride: Pooling stride to compress the sequence length. pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
unpool_length: Leading n tokens to be skipped from pooling. unpool_length: Leading n tokens to be skipped from pooling.
initializer: The initializer to use for all weights in this encoder. initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the output_range: The sequence output range, [0, output_range), by slicing the
...@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
activation='tanh', activation='tanh',
kernel_initializer=initializer, kernel_initializer=initializer,
name='pooler_transform') name='pooler_transform')
self._att_input_pool_layer = tf.keras.layers.MaxPooling1D( if isinstance(pool_stride, int):
pool_size=pool_stride, # TODO(b/197133196): Pooling layer can be shared.
strides=pool_stride, pool_strides = [pool_stride] * num_layers
padding='same', else:
name='att_input_pool_layer') if len(pool_stride) != num_layers:
self._pool_stride = pool_stride raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride
self._att_input_pool_layers = []
for layer_pool_stride in pool_strides:
att_input_pool_layer = tf.keras.layers.MaxPooling1D(
pool_size=layer_pool_stride,
strides=layer_pool_stride,
padding='same',
name='att_input_pool_layer')
self._att_input_pool_layers.append(att_input_pool_layer)
self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length self._unpool_length = unpool_length
self._config = { self._config = {
...@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask = _pool_and_concat( attention_mask = _pool_and_concat(
attention_mask, attention_mask,
unpool_length=self._unpool_length, unpool_length=self._unpool_length,
stride=self._pool_stride, strides=self._pool_strides[0],
axes=[1]) axes=[1])
for layer in self._transformer_layers: for i, layer in enumerate(self._transformer_layers):
# Pools layer for compressing the query length. # Pools layer for compressing the query length.
pooled_inputs = self._att_input_pool_layer(x[:, self._unpool_length:, :]) pooled_inputs = self._att_input_pool_layers[i](
x[:, self._unpool_length:, :])
query_inputs = tf.concat( query_inputs = tf.concat(
values=(tf.cast( values=(tf.cast(
x[:, :self._unpool_length, :], x[:, :self._unpool_length, :],
...@@ -262,11 +282,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -262,11 +282,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axis=1) axis=1)
x = layer([query_inputs, x, attention_mask]) x = layer([query_inputs, x, attention_mask])
# Pools the corresponding attention_mask. # Pools the corresponding attention_mask.
attention_mask = _pool_and_concat( if i < len(self._transformer_layers) - 1:
attention_mask, attention_mask = _pool_and_concat(
unpool_length=self._unpool_length, attention_mask,
stride=self._pool_stride, unpool_length=self._unpool_length,
axes=[1, 2]) strides=[self._pool_strides[i+1], self._pool_strides[i]],
axes=[1, 2])
encoder_outputs.append(x) encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1] last_encoder_output = encoder_outputs[-1]
......
...@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(tf.float32, data.dtype) self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(pooled_dtype, pooled.dtype) self.assertAllEqual(pooled_dtype, pooled.dtype)
def test_invalid_stride_and_num_layers(self):
hidden_size = 32
num_layers = 3
pool_stride = [2, 2]
unpool_length = 1
with self.assertRaisesRegex(ValueError,
"pool_stride and num_layers are not equal"):
_ = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
unpool_length=unpool_length)
@parameterized.named_parameters( @parameterized.named_parameters(
("no_stride_no_unpool", 1, 0), ("no_stride_no_unpool", 1, 0),
("stride_list_with_unpool", [2, 3, 4], 1),
("large_stride_with_unpool", 3, 1), ("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10), ("large_stride_with_large_unpool", 5, 10),
("no_stride_with_unpool", 1, 1), ("no_stride_with_unpool", 1, 1),
...@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape = [None, sequence_length, hidden_size] expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, num_layers) self.assertLen(all_encoder_outputs, num_layers)
for data in all_encoder_outputs: if isinstance(pool_stride, int):
expected_data_shape[1] = unpool_length + (expected_data_shape[1] + pool_stride = [pool_stride] * num_layers
pool_stride - 1 - for layer_pool_stride, data in zip(pool_stride, all_encoder_outputs):
unpool_length) // pool_stride expected_data_shape[1] = unpool_length + (
print("shapes:", expected_data_shape, data.shape.as_list()) expected_data_shape[1] + layer_pool_stride - 1 -
unpool_length) // layer_pool_stride
self.assertAllEqual(expected_data_shape, data.shape.as_list()) self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list()) self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment