Commit 96a8d744, authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 446242527
parent c7734283
...@@ -343,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -343,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
# TODO(b/203665205): unpool_length should be implemented. # TODO(b/203665205): unpool_length should be implemented.
if unpool_length != 0: if unpool_length != 0:
raise ValueError('unpool_length is not supported by truncated_avg now.') raise ValueError('unpool_length is not supported by truncated_avg now.')
# Compute the attention masks and pooling transforms.
self._pooling_transforms = _create_truncated_avg_transforms(
max_sequence_length, pool_strides)
else: else:
raise ValueError('pool_type not supported.') raise ValueError('pool_type not supported.')
...@@ -359,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -359,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
name='att_input_pool_layer') name='att_input_pool_layer')
self._att_input_pool_layers.append(att_input_pool_layer) self._att_input_pool_layers.append(att_input_pool_layer)
self._max_sequence_length = max_sequence_length
self._pool_strides = pool_strides # This is a list here. self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length self._unpool_length = unpool_length
self._pool_type = pool_type self._pool_type = pool_type
...@@ -489,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -489,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axes=[1, 2]) axes=[1, 2])
encoder_outputs.append(x) encoder_outputs.append(x)
elif self._pool_type == _TRUNCATED_AVG: elif self._pool_type == _TRUNCATED_AVG:
# Compute the attention masks and pooling transforms.
# Note we do not compute this in __init__ due to inference converter issue
# b/215659399.
pooling_transforms = _create_truncated_avg_transforms(
self._max_sequence_length, self._pool_strides)
attention_masks = _create_truncated_avg_masks(mask, self._pool_strides, attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
self._pooling_transforms) pooling_transforms)
for i, layer in enumerate(self._transformer_layers): for i, layer in enumerate(self._transformer_layers):
attention_mask = attention_masks[i] attention_mask = attention_masks[i]
# Bypass no pooling cases. # Bypass no pooling cases.
...@@ -501,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -501,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'BFD,FT->BTD', 'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype() tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
), # extra casting for faster mixed computation. ), # extra casting for faster mixed computation.
self._pooling_transforms[i]) pooling_transforms[i])
query_inputs = tf.concat( query_inputs = tf.concat(
values=(tf.cast( values=(tf.cast(
x[:, :self._unpool_length, :], x[:, :self._unpool_length, :],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment