Commit a3be7365 authored by George Karpenkov's avatar George Karpenkov Committed by A. Unique TensorFlower
Browse files

Add a CompiledTransformer layer, which is compiled with XLA.

PiperOrigin-RevId: 304222530
parent eb6e819d
......@@ -23,6 +23,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers.util import tf_function_if_eager
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -219,3 +220,10 @@ class Transformer(tf.keras.layers.Layer):
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
class CompiledTransformer(Transformer):
@tf_function_if_eager(experimental_compile=True)
def call(self, inputs):
return super(CompiledTransformer, self).call(inputs)
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -28,14 +29,16 @@ from official.nlp.modeling.layers import transformer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
@parameterized.parameters(transformer.Transformer,
transformer.CompiledTransformer)
class TransformerLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_creation(self):
test_layer = transformer.Transformer(
def test_layer_creation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -47,8 +50,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_mask(self):
test_layer = transformer.Transformer(
def test_layer_creation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -62,8 +65,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
def test_layer_creation_with_incorrect_mask_fails(self):
test_layer = transformer.Transformer(
def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -76,8 +79,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
_ = test_layer([data_tensor, mask_tensor])
def test_layer_invocation(self):
test_layer = transformer.Transformer(
def test_layer_invocation(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -97,8 +100,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
(batch_size, sequence_length, width))
_ = model.predict(input_data)
def test_layer_invocation_with_mask(self):
test_layer = transformer.Transformer(
def test_layer_invocation_with_mask(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -124,9 +127,9 @@ class TransformerLayerTest(keras_parameterized.TestCase):
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_layer_invocation_with_float16_dtype(self):
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = transformer.Transformer(
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
......@@ -152,8 +155,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_transform_with_initializer(self):
test_layer = transformer.Transformer(
def test_transform_with_initializer(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
......@@ -166,8 +169,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self):
test_layer = transformer.Transformer(
def test_dynamic_layer_sequence(self, transformer_cls):
test_layer = transformer_cls(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment