Commit 393c1399 authored by George Karpenkov, committed by A. Unique TensorFlower

Enable XLA compilation using `@tf.function(experimental_compile=True)` for the transformer layer.

To debug the tf.function, this API can be used: https://www.tensorflow.org/api_docs/python/tf/config/experimental_run_functions_eagerly
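As a minimal sketch of the pattern this change enables and of the debugging flag linked above (the function body and tensor values below are placeholders, not part of this change; only the tf.function and tf.config calls are real TensorFlow APIs):

import tensorflow as tf

@tf.function(experimental_compile=True)  # ask TensorFlow to compile this function with XLA
def compiled_fn(x):
  return tf.nn.relu(x)  # stand-in for the transformer layer's call

# Normal execution: the function is traced and compiled.
print(compiled_fn(tf.ones([2, 4])))

# For debugging, force tf.function-decorated code to run eagerly, line by line.
tf.config.experimental_run_functions_eagerly(True)
print(compiled_fn(tf.ones([2, 4])))
tf.config.experimental_run_functions_eagerly(False)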

PiperOrigin-RevId: 296458870
parent 867f0c47
@@ -193,6 +193,7 @@ class Transformer(tf.keras.layers.Layer):
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  @tf.function(experimental_compile=True)
   def call(self, inputs):
     if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
       input_tensor, attention_mask = inputs
@@ -204,6 +205,7 @@ class Transformer(tf.keras.layers.Layer):
     if attention_mask is not None:
       attention_inputs.append(attention_mask)
 
+    with tf.name_scope(self.name):
       attention_output = self._attention_layer(attention_inputs)
       attention_output = self._attention_output_dense(attention_output)
       attention_output = self._attention_dropout(attention_output)
@@ -215,7 +217,8 @@ class Transformer(tf.keras.layers.Layer):
       layer_output = self._output_dense(intermediate_output)
       layer_output = self._output_dropout(layer_output)
       # During mixed precision training, attention_output is from layer norm and
-      # is always fp32 for now. cast layer_output to fp32 for the subsequent add.
+      # is always fp32 for now. Cast layer_output to fp32 for the subsequent
+      # add.
       layer_output = tf.cast(layer_output, tf.float32)
       layer_output = self._output_layer_norm(layer_output + attention_output)
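For context on the comment in the hunk above, a small sketch of the dtype mismatch it guards against under mixed precision (the tensor shapes and values here are made up for illustration):

import tensorflow as tf

# Under mixed precision, the dense/dropout path produces float16, while the
# residual coming out of layer normalization stays float32.
layer_output = tf.random.normal([2, 8], dtype=tf.float16)
attention_output = tf.random.normal([2, 8], dtype=tf.float32)

# Adding them directly would fail with a dtype mismatch, so cast up first.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = layer_output + attention_output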