Commit 0aba91fc authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 400837715
parent 06b2d7d7
...@@ -16,20 +16,21 @@ ...@@ -16,20 +16,21 @@
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from typing import Union, Sequence from typing import Union, Sequence
from absl import logging from absl import logging
import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp import keras_nlp from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int], def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int], int],
axes: Union[Sequence[int], int]): axes: Union[Sequence[int], int]):
"""Pools the data along a given axis with stride. """Pools the mask along a given axis with stride.
It also skips first unpool_length elements. It also skips first unpool_length elements.
Args: Args:
data: Tensor to be pooled. mask: Tensor to be pooled.
unpool_length: Leading elements to be skipped. unpool_length: Leading elements to be skipped.
strides: Strides for the given axes. strides: Strides for the given axes.
axes: Axes to pool the Tensor. axes: Axes to pool the Tensor.
...@@ -45,18 +46,21 @@ def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int], ...@@ -45,18 +46,21 @@ def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
else: else:
if len(strides) != len(axes): if len(strides) != len(axes):
raise ValueError('The lengths of strides and axes need to match.') raise ValueError('The lengths of strides and axes need to match.')
# Bypass no pooling cases.
if np.all(np.array(strides) == 1):
return mask
for axis, stride in zip(axes, strides): for axis, stride in zip(axes, strides):
# Skips first `unpool_length` tokens. # Skips first `unpool_length` tokens.
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)] unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
unpool_tensor = data[unpool_tensor_shape] unpool_tensor = mask[unpool_tensor_shape]
# Pools the second half. # Pools the second half.
pool_tensor_shape = [slice(None)] * axis + [ pool_tensor_shape = [slice(None)] * axis + [
slice(unpool_length, None, stride) slice(unpool_length, None, stride)
] ]
pool_tensor = data[pool_tensor_shape] pool_tensor = mask[pool_tensor_shape]
data = tf.concat((unpool_tensor, pool_tensor), axis=axis) mask = tf.concat((unpool_tensor, pool_tensor), axis=axis)
return data return mask
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
...@@ -272,6 +276,10 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -272,6 +276,10 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
strides=self._pool_strides[0], strides=self._pool_strides[0],
axes=[1]) axes=[1])
for i, layer in enumerate(self._transformer_layers): for i, layer in enumerate(self._transformer_layers):
# Bypass no pooling cases.
if self._pool_strides[i] == 1:
x = layer([x, x, attention_mask])
else:
# Pools layer for compressing the query length. # Pools layer for compressing the query length.
pooled_inputs = self._att_input_pool_layers[i]( pooled_inputs = self._att_input_pool_layers[i](
x[:, self._unpool_length:, :]) x[:, self._unpool_length:, :])
...@@ -286,7 +294,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -286,7 +294,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask = _pool_and_concat( attention_mask = _pool_and_concat(
attention_mask, attention_mask,
unpool_length=self._unpool_length, unpool_length=self._unpool_length,
strides=[self._pool_strides[i+1], self._pool_strides[i]], strides=[self._pool_strides[i + 1], self._pool_strides[i]],
axes=[1, 2]) axes=[1, 2])
encoder_outputs.append(x) encoder_outputs.append(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment