Commit 0bd679b0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 325736862
parent 35492c3d
@@ -22,6 +22,23 @@ from __future__ import print_function
 import tensorflow as tf
 
 
+def _large_compatible_negative(tensor_type):
+  """Large negative number as Tensor.
+
+  This function is necessary because the standard value for epsilon
+  in this module (-1e9) cannot be represented using tf.float16
+
+  Args:
+    tensor_type: a dtype to determine the type.
+
+  Returns:
+    a large negative number.
+  """
+  if tensor_type == tf.float16:
+    return tf.float16.min
+  return -1e9
+
+
 @tf.keras.utils.register_keras_serializable(package='Text')
 class MaskedSoftmax(tf.keras.layers.Layer):
   """Performs a softmax with optional masking on a tensor.
@@ -50,9 +67,9 @@ class MaskedSoftmax(tf.keras.layers.Layer):
       # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
       # masked positions, this operation will create a tensor which is 0.0 for
-      # positions we want to attend and -10000.0 for masked positions.
-      adder = (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
+      # positions we want to attend and -1.e9 for masked positions.
+      adder = (1.0 - tf.cast(mask, scores.dtype)) * _large_compatible_negative(
+          scores.dtype)
       # Since we are adding it to the raw scores before the softmax, this is
       # effectively the same as removing these entirely.
       scores += adder
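For context, the hunk above implements additive attention masking: positions whose mask value is 0.0 receive a very large negative offset before the softmax, which drives their probability to essentially zero. The snippet below is an illustrative sketch only and is not part of this commit (the toy scores and mask are made up); it shows why a literal -1e9 is unsafe when the scores are float16 (the constant overflows to -inf, and 0.0 * -inf then yields NaN for the positions that should be kept) and how the dtype-aware helper avoids that.

import tensorflow as tf


def _large_compatible_negative(tensor_type):
  # Same idea as the helper added above: float16 cannot represent -1e9
  # (its most negative finite value is about -65504), so fall back to the
  # dtype's own minimum instead of a hard-coded constant.
  if tensor_type == tf.float16:
    return tf.float16.min
  return -1e9


# Toy example: one row of attention scores with the last position masked out.
scores = tf.constant([[2.0, 1.0, 0.5]], dtype=tf.float16)
mask = tf.constant([[1.0, 1.0, 0.0]], dtype=tf.float16)

adder = (1.0 - tf.cast(mask, scores.dtype)) * _large_compatible_negative(
    scores.dtype)
probs = tf.nn.softmax(scores + adder)
print(probs)  # Masked position gets ~0 probability; no NaNs in float16.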
@@ -127,8 +127,9 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])
 
-  def test_layer_output_range(self, transformer_cls):
-    test_layer = transformer_cls(
+  def test_layer_output_range(self, _):
+    # XLA has an obvious numeric issue in this test case.
+    test_layer = transformer.Transformer(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -144,7 +145,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The layer only attends to the first token and outputs the first token
     # embeeding.
-    new_layer = transformer_cls(
+    new_layer = transformer.Transformer(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu',
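The test change above keeps the parameterized method signature but ignores the injected class (the `_` argument) and pins transformer.Transformer directly, per the comment about an XLA numeric issue in this particular check. A minimal sketch of that pattern follows; the class names are hypothetical stand-ins rather than the real layers in this repository, and absl's parameterized is used in place of keras_parameterized so the sketch is self-contained.

from absl.testing import parameterized

import tensorflow as tf


# Hypothetical stand-ins for the layer variants the real test class is
# parameterized over.
class PlainTransformer(tf.keras.layers.Layer):
  pass


class CompiledTransformer(tf.keras.layers.Layer):
  pass


@parameterized.named_parameters(('base', PlainTransformer),
                                ('compiled', CompiledTransformer))
class TransformerLayerTestSketch(parameterized.TestCase):

  def test_layer_output_range(self, _):
    # The class-level parameterization still passes a layer class in, but
    # this test ignores it and pins one implementation, mirroring the diff
    # above where transformer_cls was replaced by transformer.Transformer.
    layer = PlainTransformer()
    self.assertIsInstance(layer, tf.keras.layers.Layer)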