Commit 8162ba7a authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Update layers to handle None input shape

PiperOrigin-RevId: 362313154
parent 288340b6
...@@ -590,7 +590,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -590,7 +590,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
# regular global average pooling. # regular global average pooling.
# Shape: [batch_size, 1, 1, 1, channels] # Shape: [batch_size, 1, 1, 1, channels]
x = tf.reduce_sum(inputs, axis=(1, 2, 3), keepdims=True) x = tf.reduce_sum(inputs, axis=(1, 2, 3), keepdims=True)
x = x / tf.cast(inputs.shape[2] * inputs.shape[3], x.dtype) x = x / tf.cast(tf.shape(inputs)[2] * tf.shape(inputs)[3], x.dtype)
x = x + buffer x = x + buffer
# Shape: [batch_size, 1, 1, 1, channels] # Shape: [batch_size, 1, 1, 1, channels]
...@@ -713,7 +713,7 @@ class CausalConvMixin: ...@@ -713,7 +713,7 @@ class CausalConvMixin:
# When buffer padding, use 'valid' padding across time. The output shape # When buffer padding, use 'valid' padding across time. The output shape
# across time should be the input shape minus any padding, assuming # across time should be the input shape minus any padding, assuming
# the stride across time is 1. # the stride across time is 1.
if self._use_buffered_input: if self._use_buffered_input and spatial_output_shape[0] is not None:
padding = self._compute_buffered_causal_padding(use_buffered_input=False) padding = self._compute_buffered_causal_padding(use_buffered_input=False)
spatial_output_shape[0] -= sum(padding[1]) spatial_output_shape[0] -= sum(padding[1])
return spatial_output_shape return spatial_output_shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment