Commit e0e8aa4a authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Update type annotations and register classes.

PiperOrigin-RevId: 361632206
parent 161ebfc7
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
"""Contains common building blocks for neural networks.""" """Contains common building blocks for neural networks."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -509,7 +509,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -509,7 +509,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
def call(self, def call(self,
inputs: tf.Tensor, inputs: tf.Tensor,
states: Optional[States] = None, states: Optional[States] = None,
output_states: bool = True) -> Union[Any, Tuple[Any, States]]: output_states: bool = True
) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
"""Calls the layer with the given inputs. """Calls the layer with the given inputs.
Args: Args:
...@@ -589,6 +590,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -589,6 +590,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
return (x, states) if output_states else x return (x, states) if output_states else x
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpatialAveragePool3D(tf.keras.layers.Layer): class SpatialAveragePool3D(tf.keras.layers.Layer):
"""Global average pooling layer pooling across spatial dimentions. """Global average pooling layer pooling across spatial dimentions.
""" """
...@@ -704,6 +706,7 @@ class CausalConvMixin: ...@@ -704,6 +706,7 @@ class CausalConvMixin:
return spatial_output_shape return spatial_output_shape
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin): class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin):
"""Conv2D layer supporting CausalConv. """Conv2D layer supporting CausalConv.
...@@ -751,6 +754,7 @@ class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin): ...@@ -751,6 +754,7 @@ class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin):
return self._buffered_spatial_output_shape(shape) return self._buffered_spatial_output_shape(shape)
@tf.keras.utils.register_keras_serializable(package='Vision')
class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin): class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin):
"""DepthwiseConv2D layer supporting CausalConv. """DepthwiseConv2D layer supporting CausalConv.
...@@ -812,6 +816,7 @@ class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin): ...@@ -812,6 +816,7 @@ class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin):
return self._buffered_spatial_output_shape(shape) return self._buffered_spatial_output_shape(shape)
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin): class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin):
"""Conv3D layer supporting CausalConv. """Conv3D layer supporting CausalConv.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment