"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "9588435c14d9296ba66e29bc7b906e4cecdb4dd9"
Commit e0c2c302 authored by George Karpenkov's avatar George Karpenkov Committed by A. Unique TensorFlower
Browse files

Do not recreate tf.function on different calls to the same layer

PiperOrigin-RevId: 297366158
parent 01dbd5bf
...@@ -195,11 +195,13 @@ class Transformer(tf.keras.layers.Layer): ...@@ -195,11 +195,13 @@ class Transformer(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
# TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash. # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
call_impl = self.call_impl if not hasattr(self, "_call_impl"):
if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions" self._call_impl = self.call_impl
) or tf.compat.v1.executing_eagerly_outside_functions(): if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
call_impl = tf.function(experimental_compile=True)(call_impl) ) or tf.compat.v1.executing_eagerly_outside_functions():
return call_impl(inputs) self._call_impl = tf.function(experimental_compile=True)(
self._call_impl)
return self._call_impl(inputs)
def call_impl(self, inputs): def call_impl(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2: if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment