Commit 700d29e9 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move `output_range` argument to call(). After evaluating several use cases,...

Move `output_range` argument to call(). After evaluating several use cases, `output_range` is better to be specified at `call` time which allows us the train with full sequence and serve with certain tokens.

PiperOrigin-RevId: 462885854
parent 226f9419
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
from typing import Any, Optional
from absl import logging
import tensorflow as tf
......@@ -129,6 +129,11 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
# Deprecation warning.
if output_range is not None:
logging.warning("`output_range` is avaliable as an argument for `call()`."
"The `output_range` as __init__ argument is deprecated.")
self._num_heads = num_attention_heads
self._inner_dim = inner_dim
self._inner_activation = inner_activation
......@@ -258,7 +263,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(TransformerEncoderBlock, self).build(input_shape)
super().build(input_shape)
def get_config(self):
config = {
......@@ -310,10 +315,10 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
"diff_q_kv_att_layer_norm":
self._diff_q_kv_att_layer_norm,
}
base_config = super(TransformerEncoderBlock, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
"""Transformer self-attention encoder block call.
Args:
......@@ -324,6 +329,10 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
[`query tensor`, `key value tensor`, `attention mask`] to have separate
input streams for the query, and key/value to the multi-head
attention.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced. If you
would like to have no change to the model training, it is better to only
set the `output_range` for serving.
Returns:
An output tensor with the same dimensions as input/query tensor.
......@@ -340,15 +349,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
output_range = output_range or self._output_range
if output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
source_tensor = input_tensor[:, 0:output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm_kv(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
target_tensor = input_tensor[:, 0:output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
......
......@@ -125,6 +125,9 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048,
......@@ -179,6 +182,9 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = transformer_cls(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment