"docs/vscode:/vscode.git/clone" did not exist on "074e12358bc17e7dbe111ea4f62f05dbae8a49d5"
Commit fe1fa6d6 authored by George Karpenkov, committed by A. Unique TensorFlower
Browse files

Only apply tf.function(experimental_compile=True) in eager mode

Application in graph mode still leads to some crashes.

PiperOrigin-RevId: 297144398
parent 1d6c8833
...@@ -193,8 +193,15 @@ class Transformer(tf.keras.layers.Layer): ...@@ -193,8 +193,15 @@ class Transformer(tf.keras.layers.Layer):
base_config = super(Transformer, self).get_config() base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
  """Forward pass, XLA-compiling the implementation only in eager mode.

  # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
  Wrapping with tf.function(experimental_compile=True) while building a
  TF1-style graph still crashes, so the wrap is applied only when running
  eagerly outside functions (or on builds too old to expose the check).
  """
  impl = self.call_impl
  # None when the TF build predates the eager-mode introspection API.
  eager_check = getattr(
      tf.compat.v1, "executing_eagerly_outside_functions", None)
  if eager_check is None or eager_check():
    impl = tf.function(experimental_compile=True)(impl)
  return impl(inputs)
def call_impl(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2: if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs input_tensor, attention_mask = inputs
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.