Commit 5292d16e authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

#Funnel fix `mixed_precision=None` error when mixed with TF1 code.

PiperOrigin-RevId: 407253800
parent e387ed65
......@@ -27,6 +27,13 @@ _AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg'
def _get_policy_dtype():
  """Returns the compute dtype of the global Keras mixed-precision policy.

  Falls back to `tf.float32` when the policy reports no compute dtype, or when
  looking the policy up raises `AttributeError`.

  Returns:
    The policy's `compute_dtype` if it is set, otherwise `tf.float32`.
  """
  try:
    policy = tf.keras.mixed_precision.global_policy()
    compute_dtype = policy.compute_dtype
  except AttributeError:  # tf1 has no attribute 'global_policy'
    compute_dtype = None
  # A missing (falsy) compute dtype is treated the same as an unavailable API.
  return compute_dtype if compute_dtype else tf.float32
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
int],
axes: Union[Sequence[int], int]):
......@@ -105,9 +112,7 @@ def _create_truncated_avg_transforms(seq_length: int,
transform = [[1.0 if (i // pfac) == j else 0.0
for j in range(psl)]
for i in range(sl)]
transform = tf.constant(
transform,
dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
transform = tf.constant(transform, dtype=_get_policy_dtype())
pooling_transforms.append(transform / pool_stride)
seq_length = pooled_seq_length
......@@ -125,7 +130,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
Args:
input_mask: Tensor of shape [batch_size, seq_length].
pool_strides: Sequence of pooling strides for each layer.
transforms: Sequnce of off-diagonal matrices filling with 0.0 and
transforms: Sequence of off-diagonal matrices filling with 0.0 and
1/pool_stride.
Returns:
......@@ -138,8 +143,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
attention_masks = []
seq_length = tf.shape(input_mask)[-1]
layer_mask = tf.cast(
input_mask, dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
layer_mask = tf.cast(input_mask, dtype=_get_policy_dtype())
for pool_stride, transform in zip(pool_strides, transforms):
if pool_stride == 1:
attention_masks.append(create_2d_mask(seq_length, layer_mask))
......@@ -423,8 +427,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
else:
pooled_inputs = tf.einsum(
'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :],
tf.keras.mixed_precision.global_policy().compute_dtype
tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
), # extra casting for faster mixed computation.
self._pooling_transforms[i])
query_inputs = tf.concat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment