Commit 350f4854 authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398286699
parent 9d5a1a76
......@@ -14,15 +14,16 @@
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
from typing import Union, Collection
from typing import Union, Sequence
from absl import logging
import tensorflow as tf
from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, stride: int,
axes: Union[Collection[int], int]):
def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
int],
axes: Union[Sequence[int], int]):
"""Pools the data along a given axis with stride.
It also skips first unpool_length elements.
......@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
Args:
data: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
stride: Stride for the given axis.
strides: Strides for the given axes.
axes: Axes to pool the Tensor.
Returns:
......@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
# Wraps the axes as a list.
if isinstance(axes, int):
axes = [axes]
if isinstance(strides, int):
strides = [strides] * len(axes)
else:
if len(strides) != len(axes):
raise ValueError('The lengths of strides and axes need to match.')
for axis in axes:
for axis, stride in zip(axes, strides):
# Skips first `unpool_length` tokens.
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
unpool_tensor = data[unpool_tensor_shape]
......@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
pool_stride: Pooling stride to compress the sequence length.
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
unpool_length: Leading n tokens to be skipped from pooling.
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
......@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._att_input_pool_layer = tf.keras.layers.MaxPooling1D(
pool_size=pool_stride,
strides=pool_stride,
if isinstance(pool_stride, int):
# TODO(b/197133196): Pooling layer can be shared.
pool_strides = [pool_stride] * num_layers
else:
if len(pool_stride) != num_layers:
raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride
self._att_input_pool_layers = []
for layer_pool_stride in pool_strides:
att_input_pool_layer = tf.keras.layers.MaxPooling1D(
pool_size=layer_pool_stride,
strides=layer_pool_stride,
padding='same',
name='att_input_pool_layer')
self._pool_stride = pool_stride
self._att_input_pool_layers.append(att_input_pool_layer)
self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length
self._config = {
......@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask = _pool_and_concat(
attention_mask,
unpool_length=self._unpool_length,
stride=self._pool_stride,
strides=self._pool_strides[0],
axes=[1])
for layer in self._transformer_layers:
for i, layer in enumerate(self._transformer_layers):
# Pools layer for compressing the query length.
pooled_inputs = self._att_input_pool_layer(x[:, self._unpool_length:, :])
pooled_inputs = self._att_input_pool_layers[i](
x[:, self._unpool_length:, :])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
......@@ -262,10 +282,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axis=1)
x = layer([query_inputs, x, attention_mask])
# Pools the corresponding attention_mask.
if i < len(self._transformer_layers) - 1:
attention_mask = _pool_and_concat(
attention_mask,
unpool_length=self._unpool_length,
stride=self._pool_stride,
strides=[self._pool_strides[i+1], self._pool_strides[i]],
axes=[1, 2])
encoder_outputs.append(x)
......
......@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(pooled_dtype, pooled.dtype)
def test_invalid_stride_and_num_layers(self):
hidden_size = 32
num_layers = 3
pool_stride = [2, 2]
unpool_length = 1
with self.assertRaisesRegex(ValueError,
"pool_stride and num_layers are not equal"):
_ = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
unpool_length=unpool_length)
@parameterized.named_parameters(
("no_stride_no_unpool", 1, 0),
("stride_list_with_unpool", [2, 3, 4], 1),
("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10),
("no_stride_with_unpool", 1, 1),
......@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, num_layers)
for data in all_encoder_outputs:
expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
pool_stride - 1 -
unpool_length) // pool_stride
print("shapes:", expected_data_shape, data.shape.as_list())
if isinstance(pool_stride, int):
pool_stride = [pool_stride] * num_layers
for layer_pool_stride, data in zip(pool_stride, all_encoder_outputs):
expected_data_shape[1] = unpool_length + (
expected_data_shape[1] + layer_pool_stride - 1 -
unpool_length) // layer_pool_stride
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment