Commit d182e423 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 456159384
parent a471eb3b
@@ -16,8 +16,10 @@
from typing import Dict
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import rezero_transformer
from official.nlp.modeling.layers import self_attention_mask
from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers import transformer_scaffold
def _packing_mask(segment_id, source_segment_id, dtype=tf.float32):
@@ -142,3 +144,114 @@ class StridedTransformerEncoderBlock(
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedReZeroTransformer(rezero_transformer.ReZeroTransformer):
  """ReZeroTransformer for packing optimization to stride over inputs."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if self._output_range is not None:
      raise ValueError(f'{self.__class__} does not '
                       'support `output_range` argument.')

  def call(self, inputs, stride: tf.Tensor):
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError(f'Unexpected inputs to {self.__class__} with '
                         f'length at {len(inputs)}.')
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    target_tensor = input_tensor[:, ::stride, :]
    if attention_mask is not None:
      attention_mask = attention_mask[:, ::stride, :]
    if key_value is None:
      key_value = input_tensor

    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    attention_output = target_tensor + self._rezero_a * attention_output
    if self._use_layer_norm:
      attention_output = self._attention_layer_norm(attention_output)
    else:
      attention_output = tf.cast(attention_output, tf.float32)

    intermediate_output = self._intermediate_dense(attention_output)
    intermediate_output = self._inner_activation_layer(intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    layer_output = self._output_dropout(layer_output)
    layer_output = attention_output + tf.cast(
        self._rezero_a_ffn * layer_output, tf.float32)
    if self._use_layer_norm:
      layer_output = self._output_layer_norm(layer_output)
    return layer_output
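# Note on the striding pattern shared by the strided layers in this file:
# queries are taken from every `stride`-th position while keys/values span
# the full packed sequence, so a [batch, seq_len, hidden] input yields a
# [batch, ceil(seq_len / stride), hidden] output, and the attention mask is
# strided along the query (from_seq) axis only.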
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
  """TransformerScaffold for packing optimization to stride over inputs."""

  def call(self, inputs, stride: tf.Tensor, training=None):
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)

    if self._norm_first:
      source_tensor = input_tensor[:, ::stride, :]
      input_tensor = self._attention_layer_norm(input_tensor,
                                                training=training)

    if attention_mask is not None:
      attention_mask = attention_mask[:, ::stride, :]

    target_tensor = input_tensor[:, ::stride, :]
    attention_output = self._attention_layer(
        query=target_tensor, value=input_tensor, attention_mask=attention_mask,
        training=training)
    attention_output = self._attention_dropout(attention_output,
                                               training=training)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(target_tensor +
                                                    attention_output,
                                                    training=training)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output,
                                                 training=training)

    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
          intermediate_output)
      layer_output = self._output_dense(intermediate_output, training=training)
      layer_output = self._output_dropout(layer_output, training=training)
      layer_output = tf.cast(layer_output, tf.float32)
      if self._norm_first:
        layer_output = source_attention_output + layer_output
      else:
        layer_output = self._output_layer_norm(layer_output + attention_output,
                                               training=training)
    else:
      if self._norm_first:
        # If norm_first, assume the feedforward block will not apply layer
        # norm.
        layer_output = self._feedforward_block(attention_output,
                                               training=training)
        layer_output += source_attention_output
      else:
        # If not norm_first, assume that the feedforward block does apply
        # layer norm.
        layer_output = self._feedforward_block(attention_output,
                                               training=training)

    return layer_output
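For reference, a minimal usage sketch of the new layers, mirroring the tests
below (the constructor arguments shown are the ones the tests use; `stride`
must be a scalar integer tensor):

import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization

# A packed batch: 2 sequences of length 4 with hidden size 8.
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
# Full [batch, from_seq, to_seq] mask; the layer strides the query axis.
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)

layer = pack_optimization.StridedReZeroTransformer(
    num_attention_heads=2, inner_dim=4, inner_activation='relu')
outputs = layer([inputs, attention_mask],
                stride=tf.constant(2, dtype=tf.int32))
print(outputs.shape)  # (2, 2, 8): the query axis shrank by the stride.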
@@ -37,8 +37,29 @@ class PackOptimizationTest(tf.test.TestCase):
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    transformer = pack_optimization.StridedTransformerEncoderBlock(
        num_attention_heads=2, inner_dim=4, inner_activation="relu")
    outputs = transformer([inputs, attention_mask],
                          stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))

  def test_strided_rezero_transformer(self):
    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    transformer = pack_optimization.StridedReZeroTransformer(
        num_attention_heads=2, inner_dim=4, inner_activation="relu")
    outputs = transformer([inputs, attention_mask],
                          stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))

  def test_strided_scaffold(self):
    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
    test_layer = pack_optimization.StridedTransformerScaffold(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation="relu")
    outputs = test_layer([inputs, attention_mask],
                         stride=tf.constant(2, dtype=tf.int32))
    self.assertEqual(outputs.shape, (2, 2, 8))
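# With seq_len=4 and stride=2, each strided layer keeps 2 query positions
# (indices 0 and 2), hence the (2, 2, 8) shape asserted above.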
if __name__ == "__main__":
...