Commit bae103c5 authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404920138
parent ae82b280
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Funnel Transformer network.""" """Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from typing import Union, Sequence from typing import Union, Sequence
from absl import logging from absl import logging
import numpy as np import numpy as np
...@@ -21,6 +22,10 @@ import tensorflow as tf ...@@ -21,6 +22,10 @@ import tensorflow as tf
from official.nlp.modeling import layers from official.nlp.modeling import layers
_MAX = 'max'
_AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg'
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int], def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int], int],
...@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int], ...@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
return mask return mask
def _create_truncated_avg_transforms(seq_length: int,
                                     pool_strides: Sequence[int]):
  """Builds the per-layer truncated-average pooling matrices.

  For a layer with stride `s` acting on length `L`, the transform has shape
  [L, L // s] with

    transform[i, j] = 1.0 / s   if i // s == j
                      0.0       otherwise.

  This is average pooling that simply drops (truncates) the trailing
  positions when L % s != 0. For seq_length == 6 and pool_stride == 2:

    [[0.5, 0.0, 0.0],
     [0.5, 0.0, 0.0],
     [0.0, 0.5, 0.0],
     [0.0, 0.5, 0.0],
     [0.0, 0.0, 0.5],
     [0.0, 0.0, 0.5]]

  Args:
    seq_length: int, sequence length fed into the first layer.
    pool_strides: pooling stride for each layer.

  Returns:
    A list with one entry per layer: the pooling transform Tensor, or None
    for layers whose stride is 1 (no pooling applied).
  """
  transforms = []
  current_length = seq_length
  for stride in pool_strides:
    if stride == 1:
      # Stride-1 layers keep the sequence as-is; no transform needed.
      transforms.append(None)
    else:
      pooled_length = current_length // stride
      rows = [[1.0 if i // stride == j else 0.0
               for j in range(pooled_length)]
              for i in range(current_length)]
      # Build in the active compute dtype so mixed-precision einsums avoid
      # implicit casts later.
      matrix = tf.constant(
          rows,
          dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
      transforms.append(matrix / stride)
      current_length = pooled_length
  return transforms
def _create_truncated_avg_masks(input_mask: tf.Tensor,
                                pool_strides: Sequence[int],
                                transforms: Sequence[tf.Tensor]):
  """Builds per-layer 3-D attention masks for truncated-average pooling.

  Each layer's mask broadcasts the current 1-D padding mask over the "from"
  dimension. After a pooling layer, the 1-D mask itself is pooled with that
  layer's transform and re-binarized: a pooled position stays valid (1) if
  it overlaps at least one valid input token (pooled value > 0).

  Args:
    input_mask: Tensor of shape [batch_size, seq_length].
    pool_strides: pooling stride for each layer.
    transforms: per-layer pooling transforms as produced by
      `_create_truncated_avg_transforms` (None for stride-1 layers).

  Returns:
    attention_masks: a list of [batch, from_length, to_length] masks, one
      per layer.
  """

  def _broadcast_to_3d(from_length, mask_1d):
    # [B, T] -> [B, F, T] by repeating the row mask along the from-axis.
    ones = tf.ones([from_length], dtype=mask_1d.dtype)
    return tf.einsum('F,BT->BFT', ones, mask_1d)

  masks = []
  current_length = tf.shape(input_mask)[-1]
  compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
  current_mask = tf.cast(input_mask, dtype=compute_dtype)
  for stride, transform in zip(pool_strides, transforms):
    if stride == 1:
      masks.append(_broadcast_to_3d(current_length, current_mask))
    else:
      pooled_length = current_length // stride
      # Queries are pooled (from-dim shrinks) but keys/values are not yet,
      # so the to-dim keeps the pre-pooling mask for this layer.
      masks.append(_broadcast_to_3d(pooled_length, current_mask))
      pooled = tf.einsum('BF,FT->BT', current_mask, transform)
      current_mask = tf.cast(pooled > 0.0, dtype=current_mask.dtype)
      current_length = pooled_length
  return masks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class FunnelTransformerEncoder(tf.keras.layers.Layer): class FunnelTransformerEncoder(tf.keras.layers.Layer):
"""Funnel Transformer-based encoder network. """Funnel Transformer-based encoder network.
...@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout. dropout.
attention_dropout: The dropout rate to use for the attention layers within attention_dropout: The dropout rate to use for the attention layers within
the transformer layers. the transformer layers.
pool_type: Pooling type. Choose from ['max', 'avg']. pool_type: Pooling type. Choose from ['max', 'avg', 'truncated_avg'].
pool_stride: An int or a list of ints. Pooling stride(s) to compress the pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size. sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers. If set to list, the number of elements needs to match num_layers.
...@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True), inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1, output_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
pool_type='max', pool_type=_MAX,
pool_stride=2, pool_stride=2,
unpool_length=0, unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
...@@ -207,12 +300,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -207,12 +300,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
raise ValueError('Lengths of pool_stride and num_layers are not equal.') raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride pool_strides = pool_stride
# TODO(crickwu): explore tf.keras.layers.serialize method. # TODO(crickwu): explore tf.keras.layers.serialize method.
if pool_type == 'max': if pool_type == _MAX:
pool_cls = tf.keras.layers.MaxPooling1D pool_cls = tf.keras.layers.MaxPooling1D
elif pool_type == 'avg': elif pool_type == _AVG:
pool_cls = tf.keras.layers.AveragePooling1D pool_cls = tf.keras.layers.AveragePooling1D
elif pool_type == _TRUNCATED_AVG:
# TODO(b/203665205): unpool_length should be implemented.
if unpool_length != 0:
raise ValueError('unpool_length is not supported by truncated_avg now.')
# Compute the attention masks and pooling transforms.
self._pooling_transforms = _create_truncated_avg_transforms(
max_sequence_length, pool_strides)
else: else:
raise ValueError('pool_type not supported.') raise ValueError('pool_type not supported.')
if pool_type in (_MAX, _AVG):
self._att_input_pool_layers = [] self._att_input_pool_layers = []
for layer_pool_stride in pool_strides: for layer_pool_stride in pool_strides:
att_input_pool_layer = pool_cls( att_input_pool_layer = pool_cls(
...@@ -224,6 +326,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -224,6 +326,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._pool_strides = pool_strides # This is a list here. self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length self._unpool_length = unpool_length
self._pool_type = pool_type
self._config = { self._config = {
'vocab_size': vocab_size, 'vocab_size': vocab_size,
...@@ -280,11 +383,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -280,11 +383,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
encoder_outputs = [] encoder_outputs = []
x = embeddings x = embeddings
# TODO(b/195972228): attention_mask can be co-generated with pooling. # TODO(b/195972228): attention_mask can be co-generated with pooling.
if self._pool_type in (_MAX, _AVG):
attention_mask = _pool_and_concat( attention_mask = _pool_and_concat(
attention_mask, attention_mask,
unpool_length=self._unpool_length, unpool_length=self._unpool_length,
strides=self._pool_strides[0], strides=self._pool_strides[0],
axes=[1]) axes=[1])
for i, layer in enumerate(self._transformer_layers): for i, layer in enumerate(self._transformer_layers):
# Bypass no pooling cases. # Bypass no pooling cases.
if self._pool_strides[i] == 1: if self._pool_strides[i] == 1:
...@@ -307,12 +412,36 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -307,12 +412,36 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
strides=[self._pool_strides[i + 1], self._pool_strides[i]], strides=[self._pool_strides[i + 1], self._pool_strides[i]],
axes=[1, 2]) axes=[1, 2])
encoder_outputs.append(x) encoder_outputs.append(x)
elif self._pool_type == _TRUNCATED_AVG:
attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
self._pooling_transforms)
for i, layer in enumerate(self._transformer_layers):
attention_mask = attention_masks[i]
# Bypass no pooling cases.
if self._pool_strides[i] == 1:
x = layer([x, x, attention_mask])
else:
pooled_inputs = tf.einsum(
'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :],
tf.keras.mixed_precision.global_policy().compute_dtype
), # extra casting for faster mixed computation.
self._pooling_transforms[i])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
dtype=pooled_inputs.dtype), pooled_inputs),
axis=1)
x = layer([query_inputs, x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1] last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :] first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor) pooled_output = self._pooler_layer(first_token_tensor)
return dict( return dict(
word_embeddings=word_embeddings,
embedding_output=embeddings,
sequence_output=encoder_outputs[-1], sequence_output=encoder_outputs[-1],
pooled_output=pooled_output, pooled_output=pooled_output,
encoder_outputs=encoder_outputs) encoder_outputs=encoder_outputs)
......
...@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
tf.keras.mixed_precision.set_global_policy("float32") tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters( @parameterized.named_parameters(
("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg"),
("float32_truncated_avg", "float32", tf.float32, "truncated_avg"),
("mix_max", "mixed_float16", tf.float16, "max"), ("mix_max", "mixed_float16", tf.float16, "max"),
("float32_max", "float32", tf.float32, "max"), ("float32_max", "float32", tf.float32, "max"),
("mix_avg", "mixed_float16", tf.float16, "avg"), ("mix_avg", "mixed_float16", tf.float16, "avg"),
...@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
num_layers=num_layers, num_layers=num_layers,
pool_stride=pool_stride, pool_stride=pool_stride,
pool_type=pool_type, pool_type=pool_type,
max_sequence_length=sequence_length,
unpool_length=0) unpool_length=0)
# Create the inputs (note that the first dimension is implicit). # Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense) self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
# Stride=2 compresses sequence length to half the size at each layer. # Stride=2 compresses sequence length to half the size at each layer.
# This configuration gives each layer of seq length: 21->11->6->3. # For pool_type = max or avg,
# this configuration gives each layer of seq length: 21->11->6->3.
# For pool_type = truncated_avg,
# seq length: 21->10->5->2.
if pool_type in ["max", "avg"]:
expected_data_shape = [None, 3, hidden_size] expected_data_shape = [None, 3, hidden_size]
else:
expected_data_shape = [None, 2, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list()) self.assertAllEqual(expected_data_shape, data.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment