Commit 56c8a503 authored by Yuexin Wu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 471177782
parent b0cf164b
@@ -14,7 +14,9 @@
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
from typing import Optional
from absl import logging
import gin
import tensorflow as tf
@@ -80,6 +82,11 @@ class ReZeroTransformer(tf.keras.layers.Layer):
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
# Deprecation warning.
if output_range is not None:
  logging.warning("`output_range` is available as an argument for `call()`. "
                  "The `output_range` as __init__ argument is deprecated.")
self._num_heads = num_attention_heads
self._inner_dim = inner_dim
self._inner_activation = inner_activation
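For context, a minimal sketch of the two usage patterns this warning distinguishes (the import path, hyperparameters, and tensor shapes below are illustrative assumptions, not part of this change):

import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer

# Deprecated: fixing output_range at construction time now triggers the warning.
layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=2, inner_dim=128, inner_activation="relu",
    output_range=1)

# Preferred after this change: pass output_range to call() instead.
layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=2, inner_dim=128, inner_activation="relu")
data = tf.random.uniform((4, 16, 64))   # [batch, seq_len, hidden]
mask = tf.ones((4, 16, 16))             # [batch, from_seq, to_seq]
outputs = layer([data, mask], output_range=1)  # -> [4, 1, 64]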
@@ -237,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
if not self._share_rezero:
  self._rezero_a_ffn.assign(0.)
def call(self, inputs):
def call(self, inputs, output_range: Optional[tf.Tensor] = None) -> tf.Tensor:
if isinstance(inputs, (list, tuple)):
  if len(inputs) == 2:
    input_tensor, attention_mask = inputs
@@ -250,10 +257,12 @@
else:
  input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
  target_tensor = input_tensor[:, 0:self._output_range, :]
if output_range is None:
  output_range = self._output_range
if output_range:
  target_tensor = input_tensor[:, 0:output_range, :]
  if attention_mask is not None:
    attention_mask = attention_mask[:, 0:self._output_range, :]
    attention_mask = attention_mask[:, 0:output_range, :]
else:
  target_tensor = input_tensor
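To make the slicing above concrete: when output_range is set (now resolved from the call() argument, with the deprecated __init__ value as fallback), only the first output_range positions are kept as attention queries, and the mask is narrowed along the same axis. A shape-only sketch with made-up sizes:

import tensorflow as tf

input_tensor = tf.random.uniform((2, 8, 16))  # [batch, seq_len, hidden]
attention_mask = tf.ones((2, 8, 8))           # [batch, from_seq, to_seq]
output_range = 1

target_tensor = input_tensor[:, 0:output_range, :]     # [2, 1, 16]
attention_mask = attention_mask[:, 0:output_range, :]  # [2, 1, 8]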
@@ -270,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._inner_activation_layer(
    intermediate_output)
intermediate_output = self._inner_activation_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
......
@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_separate_qkv(self):
  test_layer = rezero_transformer.ReZeroTransformer(
      num_attention_heads=2,
......
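The assertion added to the test above checks that passing output_range=1 at call time matches slicing the full-sequence output, up to the stated tolerances. A self-contained sketch of that equivalence (layer hyperparameters and shapes here are arbitrary, not the test's exact values):

import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer

layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=2, inner_dim=32, inner_activation="relu")
data = tf.random.uniform((3, 10, 16))
mask = tf.ones((3, 10, 10))

full = layer([data, mask])                    # [3, 10, 16]
ranged = layer([data, mask], output_range=1)  # [3, 1, 16]
tf.debugging.assert_near(full[:, 0:1, :], ranged, atol=5e-5, rtol=0.003)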
@@ -349,7 +349,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
else:
  input_tensor, key_value, attention_mask = (inputs, None, None)
output_range = output_range or self._output_range
if output_range is None:
  output_range = self._output_range
if output_range:
  if self._norm_first:
    source_tensor = input_tensor[:, 0:output_range, :]
......
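One detail worth noting in the TransformerEncoderBlock hunk: replacing `output_range or self._output_range` with an explicit `is None` check means only an omitted argument falls back to the __init__ value, not any falsy one. A tiny sketch of the difference (function names are hypothetical, for illustration only):

def resolve_with_or(call_value, init_value):
  return call_value or init_value  # a falsy call_value (e.g. 0) falls back

def resolve_with_none_check(call_value, init_value):
  return init_value if call_value is None else call_value  # only None falls back

assert resolve_with_or(0, 4) == 4
assert resolve_with_none_check(0, 4) == 0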