Commit 5292d16e authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

#Funnel fix `mixed_precision=None` error when mixed with TF1 code.

PiperOrigin-RevId: 407253800
parent e387ed65
...@@ -27,6 +27,13 @@ _AVG = 'avg' ...@@ -27,6 +27,13 @@ _AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg' _TRUNCATED_AVG = 'truncated_avg'
def _get_policy_dtype():
  """Returns the compute dtype of the global Keras mixed-precision policy.

  `tf.float32` is returned instead when the policy's `compute_dtype` is falsy
  (for example `None`) or when looking the policy up raises `AttributeError`.

  Returns:
    The global policy's `compute_dtype`, or `tf.float32` as the fallback.
  """
  default_dtype = tf.float32
  try:
    policy = tf.keras.mixed_precision.global_policy()
    return policy.compute_dtype or default_dtype
  except AttributeError:  # tf1 has no attribute 'global_policy'
    return default_dtype
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int], def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int], int],
axes: Union[Sequence[int], int]): axes: Union[Sequence[int], int]):
...@@ -105,9 +112,7 @@ def _create_truncated_avg_transforms(seq_length: int, ...@@ -105,9 +112,7 @@ def _create_truncated_avg_transforms(seq_length: int,
transform = [[1.0 if (i // pfac) == j else 0.0 transform = [[1.0 if (i // pfac) == j else 0.0
for j in range(psl)] for j in range(psl)]
for i in range(sl)] for i in range(sl)]
transform = tf.constant( transform = tf.constant(transform, dtype=_get_policy_dtype())
transform,
dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
pooling_transforms.append(transform / pool_stride) pooling_transforms.append(transform / pool_stride)
seq_length = pooled_seq_length seq_length = pooled_seq_length
...@@ -125,7 +130,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor, ...@@ -125,7 +130,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
Args: Args:
input_mask: Tensor of shape [batch_size, seq_length]. input_mask: Tensor of shape [batch_size, seq_length].
pool_strides: Sequence of pooling strides for each layer. pool_strides: Sequence of pooling strides for each layer.
transforms: Sequnce of off-diagonal matrices filling with 0.0 and transforms: Sequence of off-diagonal matrices filling with 0.0 and
1/pool_stride. 1/pool_stride.
Returns: Returns:
...@@ -138,8 +143,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor, ...@@ -138,8 +143,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
attention_masks = [] attention_masks = []
seq_length = tf.shape(input_mask)[-1] seq_length = tf.shape(input_mask)[-1]
layer_mask = tf.cast( layer_mask = tf.cast(input_mask, dtype=_get_policy_dtype())
input_mask, dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
for pool_stride, transform in zip(pool_strides, transforms): for pool_stride, transform in zip(pool_strides, transforms):
if pool_stride == 1: if pool_stride == 1:
attention_masks.append(create_2d_mask(seq_length, layer_mask)) attention_masks.append(create_2d_mask(seq_length, layer_mask))
...@@ -423,8 +427,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -423,8 +427,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
else: else:
pooled_inputs = tf.einsum( pooled_inputs = tf.einsum(
'BFD,FT->BTD', 'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :], tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
tf.keras.mixed_precision.global_policy().compute_dtype
), # extra casting for faster mixed computation. ), # extra casting for faster mixed computation.
self._pooling_transforms[i]) self._pooling_transforms[i])
query_inputs = tf.concat( query_inputs = tf.concat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment