Commit 0aba91fc authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 400837715
parent 06b2d7d7
......@@ -16,20 +16,21 @@
# pylint: disable=g-classes-have-attributes
from typing import Union, Sequence
from absl import logging
import numpy as np
import tensorflow as tf
from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int],
axes: Union[Sequence[int], int]):
"""Pools the data along a given axis with stride.
"""Pools the mask along a given axis with stride.
It also skips first unpool_length elements.
Args:
data: Tensor to be pooled.
mask: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
strides: Strides for the given axes.
axes: Axes to pool the Tensor.
......@@ -45,18 +46,21 @@ def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
else:
if len(strides) != len(axes):
raise ValueError('The lengths of strides and axes need to match.')
# Bypass no pooling cases.
if np.all(np.array(strides) == 1):
return mask
for axis, stride in zip(axes, strides):
# Skips first `unpool_length` tokens.
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
unpool_tensor = data[unpool_tensor_shape]
unpool_tensor = mask[unpool_tensor_shape]
# Pools the second half.
pool_tensor_shape = [slice(None)] * axis + [
slice(unpool_length, None, stride)
]
pool_tensor = data[pool_tensor_shape]
data = tf.concat((unpool_tensor, pool_tensor), axis=axis)
return data
pool_tensor = mask[pool_tensor_shape]
mask = tf.concat((unpool_tensor, pool_tensor), axis=axis)
return mask
@tf.keras.utils.register_keras_serializable(package='Text')
......@@ -272,6 +276,10 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
strides=self._pool_strides[0],
axes=[1])
for i, layer in enumerate(self._transformer_layers):
# Bypass no pooling cases.
if self._pool_strides[i] == 1:
x = layer([x, x, attention_mask])
else:
# Pools layer for compressing the query length.
pooled_inputs = self._att_input_pool_layers[i](
x[:, self._unpool_length:, :])
......@@ -286,7 +294,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask = _pool_and_concat(
attention_mask,
unpool_length=self._unpool_length,
strides=[self._pool_strides[i+1], self._pool_strides[i]],
strides=[self._pool_strides[i + 1], self._pool_strides[i]],
axes=[1, 2])
encoder_outputs.append(x)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment