Commit e0c2c302 authored by George Karpenkov's avatar George Karpenkov Committed by A. Unique TensorFlower
Browse files

Do not recreate tf.function on different calls to the same layer

PiperOrigin-RevId: 297366158
parent 01dbd5bf
......@@ -195,11 +195,13 @@ class Transformer(tf.keras.layers.Layer):
def call(self, inputs):
# TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
call_impl = self.call_impl
if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
) or tf.compat.v1.executing_eagerly_outside_functions():
call_impl = tf.function(experimental_compile=True)(call_impl)
return call_impl(inputs)
if not hasattr(self, "_call_impl"):
self._call_impl = self.call_impl
if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
) or tf.compat.v1.executing_eagerly_outside_functions():
self._call_impl = tf.function(experimental_compile=True)(
self._call_impl)
return self._call_impl(inputs)
def call_impl(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment