Commit ef21dabb authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Remove private Keras API usage.

PiperOrigin-RevId: 382833629
parent c1e086ef
...@@ -941,15 +941,13 @@ class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin): ...@@ -941,15 +941,13 @@ class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin):
base_config = super(Conv3D, self).get_config() base_config = super(Conv3D, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape): def call(self, inputs):
"""Builds the layer with the given input shape.""" """Call the layer with the given inputs."""
super(Conv3D, self).build(input_shape) # Note: tf.nn.conv3d with depthwise kernels on CPU is currently only
# supported when compiling with TF graph (XLA) using tf.function, so it
# TODO(b/177662019): tf.nn.conv3d with depthwise kernels on CPU # is compiled by default here (b/186463870).
# in eager mode may produce incorrect output or cause a segfault. conv_fn = tf.function(super(Conv3D, self).call, jit_compile=True)
# To avoid this issue, compile the op to TF graph using tf.function. return conv_fn(inputs)
self._convolution_op = tf.function(
self._convolution_op, experimental_compile=True)
def _compute_causal_padding(self, inputs): def _compute_causal_padding(self, inputs):
"""Computes causal padding dimensions for the given inputs.""" """Computes causal padding dimensions for the given inputs."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment