Commit ef61bae9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Update nlp.modeling.layers.ReZeroTransformer, to have the same interface with...

Update nlp.modeling.layers.ReZeroTransformer, to have the same interface with nlp.modeling.layers.Transformer

PiperOrigin-RevId: 311937563
parent e5c9661a
......@@ -42,6 +42,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
output_range: the sequence output range, [0, output_range) by slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
......@@ -58,6 +60,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
......@@ -74,6 +77,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
......@@ -176,6 +180,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"use_layer_norm":
self._use_layer_norm,
"kernel_initializer":
......@@ -205,11 +211,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
if self._output_range:
target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = input_tensor + self._rezero_a * attention_output
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
attention_output = self._attention_layer_norm(attention_output)
else:
......
......@@ -101,6 +101,33 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
self.assertAllClose(input_data_normed, output_data)
def test_layer_output_range(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
output_tensor = test_layer([input_data, mask_data])
# The layer only attends to the first token and outputs the first token
# embeeding.
new_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment