Commit 96a8d744 authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 446242527
parent c7734283
......@@ -343,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
# TODO(b/203665205): unpool_length should be implemented.
if unpool_length != 0:
raise ValueError('unpool_length is not supported by truncated_avg now.')
# Compute the attention masks and pooling transforms.
self._pooling_transforms = _create_truncated_avg_transforms(
max_sequence_length, pool_strides)
else:
raise ValueError('pool_type not supported.')
......@@ -359,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
name='att_input_pool_layer')
self._att_input_pool_layers.append(att_input_pool_layer)
self._max_sequence_length = max_sequence_length
self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length
self._pool_type = pool_type
......@@ -489,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axes=[1, 2])
encoder_outputs.append(x)
elif self._pool_type == _TRUNCATED_AVG:
# Compute the attention masks and pooling transforms.
# Note we do not compute this in __init__ due to inference converter issue
# b/215659399.
pooling_transforms = _create_truncated_avg_transforms(
self._max_sequence_length, self._pool_strides)
attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
self._pooling_transforms)
pooling_transforms)
for i, layer in enumerate(self._transformer_layers):
attention_mask = attention_masks[i]
# Bypass no pooling cases.
......@@ -501,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
), # extra casting for faster mixed computation.
self._pooling_transforms[i])
pooling_transforms[i])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment