Commit a3be7365 authored by George Karpenkov, committed by A. Unique TensorFlower

Add a CompiledTransformer layer, which is compiled with XLA.

PiperOrigin-RevId: 304222530
parent eb6e819d
official/nlp/modeling/layers/transformer.py
@@ -23,6 +23,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import attention
 from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers.util import tf_function_if_eager


 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -219,3 +220,10 @@ class Transformer(tf.keras.layers.Layer):
     layer_output = self._output_layer_norm(layer_output + attention_output)
     return layer_output
+
+
+class CompiledTransformer(Transformer):
+
+  @tf_function_if_eager(experimental_compile=True)
+  def call(self, inputs):
+    return super(CompiledTransformer, self).call(inputs)
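Note on the decorator used above: tf_function_if_eager lives in official/nlp/modeling/layers/util.py, which is not part of this diff. A minimal sketch of what such a helper plausibly does, based only on its name and on how it is applied here (wrap the method in tf.function with XLA compilation when running eagerly, leave it untouched otherwise). This is an illustration, not the file's actual code:

import tensorflow as tf

def tf_function_if_eager(**kwargs):
  """Hypothetical sketch: apply tf.function(**kwargs) only under eager execution."""
  def decorator(func):
    if tf.executing_eagerly():
      # experimental_compile=True asks tf.function to compile the traced
      # function with XLA (the flag was later renamed jit_compile).
      return tf.function(func, **kwargs)
    return func
  return decorator

Under this reading, graph/V1 mode is left unchanged, so CompiledTransformer behaves exactly like Transformer there; only the eager path gets the XLA-compiled tf.function wrapper.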
official/nlp/modeling/layers/transformer_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -28,14 +29,16 @@ from official.nlp.modeling.layers import transformer
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
+@parameterized.parameters(transformer.Transformer,
+                          transformer.CompiledTransformer)
 class TransformerLayerTest(keras_parameterized.TestCase):

   def tearDown(self):
     super(TransformerLayerTest, self).tearDown()
     tf.keras.mixed_precision.experimental.set_policy('float32')

-  def test_layer_creation(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -47,8 +50,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_mask(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation_with_mask(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -62,8 +65,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_incorrect_mask_fails(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -76,8 +79,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
       _ = test_layer([data_tensor, mask_tensor])

-  def test_layer_invocation(self):
-    test_layer = transformer.Transformer(
+  def test_layer_invocation(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -97,8 +100,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         (batch_size, sequence_length, width))
     _ = model.predict(input_data)

-  def test_layer_invocation_with_mask(self):
-    test_layer = transformer.Transformer(
+  def test_layer_invocation_with_mask(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -124,9 +127,9 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])

-  def test_layer_invocation_with_float16_dtype(self):
+  def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
-    test_layer = transformer.Transformer(
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -152,8 +155,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])

-  def test_transform_with_initializer(self):
-    test_layer = transformer.Transformer(
+  def test_transform_with_initializer(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu',
@@ -166,8 +169,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())

-  def test_dynamic_layer_sequence(self):
-    test_layer = transformer.Transformer(
+  def test_dynamic_layer_sequence(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu',
...
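For readers who want to try the new layer outside the tests, a minimal usage sketch (not part of this commit; the constructor arguments mirror the tests above, and the input shape is only illustrative):

import tensorflow as tf
from official.nlp.modeling.layers import transformer

test_layer = transformer.CompiledTransformer(
    num_attention_heads=10,
    intermediate_size=2048,
    intermediate_activation='relu')

# Illustrative input of shape (batch, sequence_length, width); the width is
# chosen to be divisible by num_attention_heads.
data = tf.random.uniform((2, 21, 80))
output = test_layer(data)
# Because call() is wrapped with tf_function_if_eager(experimental_compile=True),
# eager invocations of the layer run through an XLA-compiled tf.function; the
# output shape matches the input shape, as the tests above assert.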