Commit c32ce7cf authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 410640806
parent 51cb03b0
...@@ -884,7 +884,6 @@ class AssembleNet(tf.keras.Model): ...@@ -884,7 +884,6 @@ class AssembleNet(tf.keras.Model):
inputs=original_inputs, outputs=streams, **kwargs) inputs=original_inputs, outputs=streams, **kwargs)
@tf.keras.utils.register_keras_serializable(package='Vision')
class AssembleNetModel(tf.keras.Model): class AssembleNetModel(tf.keras.Model):
"""An AssembleNet model builder.""" """An AssembleNet model builder."""
......
...@@ -48,7 +48,6 @@ HOURGLASS_SPECS = { ...@@ -48,7 +48,6 @@ HOURGLASS_SPECS = {
} }
@tf.keras.utils.register_keras_serializable(package='centernet')
class Hourglass(tf.keras.Model): class Hourglass(tf.keras.Model):
"""CenterNet Hourglass backbone.""" """CenterNet Hourglass backbone."""
......
...@@ -21,7 +21,6 @@ import tensorflow as tf ...@@ -21,7 +21,6 @@ import tensorflow as tf
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHead(tf.keras.Model): class CenterNetHead(tf.keras.Model):
"""CenterNet Head.""" """CenterNet Head."""
......
...@@ -123,7 +123,6 @@ def _make_repeated_residual_blocks( ...@@ -123,7 +123,6 @@ def _make_repeated_residual_blocks(
return tf.keras.Sequential(blocks) return tf.keras.Sequential(blocks)
@tf.keras.utils.register_keras_serializable(package='centernet')
class HourglassBlock(tf.keras.layers.Layer): class HourglassBlock(tf.keras.layers.Layer):
"""Hourglass module: an encoder-decoder block.""" """Hourglass module: an encoder-decoder block."""
...@@ -274,7 +273,6 @@ class HourglassBlock(tf.keras.layers.Layer): ...@@ -274,7 +273,6 @@ class HourglassBlock(tf.keras.layers.Layer):
return config return config
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHeadConv(tf.keras.layers.Layer): class CenterNetHeadConv(tf.keras.layers.Layer):
"""Convolution block for the CenterNet head.""" """Convolution block for the CenterNet head."""
......
...@@ -30,7 +30,6 @@ from official.vision.beta.projects.centernet.ops import loss_ops ...@@ -30,7 +30,6 @@ from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import nms_ops from official.vision.beta.projects.centernet.ops import nms_ops
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetDetectionGenerator(tf.keras.layers.Layer): class CenterNetDetectionGenerator(tf.keras.layers.Layer):
"""CenterNet Detection Generator.""" """CenterNet Detection Generator."""
......
...@@ -23,7 +23,6 @@ from official.modeling import tf_utils ...@@ -23,7 +23,6 @@ from official.modeling import tf_utils
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import hourglass_network from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import hourglass_network
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskHead(tf.keras.layers.Layer): class DeepMaskHead(tf.keras.layers.Layer):
"""Creates a mask head.""" """Creates a mask head."""
......
...@@ -31,7 +31,6 @@ def resize_as(source, size): ...@@ -31,7 +31,6 @@ def resize_as(source, size):
return tf.transpose(source, (0, 3, 1, 2)) return tf.transpose(source, (0, 3, 1, 2))
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel): class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Mask R-CNN model.""" """The Mask R-CNN model."""
......
...@@ -25,7 +25,6 @@ import tensorflow as tf ...@@ -25,7 +25,6 @@ import tensorflow as tf
from official.vision.beta.projects.example import example_config as example_cfg from official.vision.beta.projects.example import example_config as example_cfg
@tf.keras.utils.register_keras_serializable(package='Vision')
class ExampleModel(tf.keras.Model): class ExampleModel(tf.keras.Model):
"""A example model class. """A example model class.
......
...@@ -21,7 +21,6 @@ import tensorflow as tf ...@@ -21,7 +21,6 @@ import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model from official.vision.beta.modeling import maskrcnn_model
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model.""" """The Panoptic Segmentation model."""
......
...@@ -24,7 +24,6 @@ regularizers = tf.keras.regularizers ...@@ -24,7 +24,6 @@ regularizers = tf.keras.regularizers
layers = tf.keras.layers layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='simclr')
class ProjectionHead(tf.keras.layers.Layer): class ProjectionHead(tf.keras.layers.Layer):
"""Projection head.""" """Projection head."""
...@@ -144,7 +143,6 @@ class ProjectionHead(tf.keras.layers.Layer): ...@@ -144,7 +143,6 @@ class ProjectionHead(tf.keras.layers.Layer):
return proj_head_output, proj_finetune_output return proj_head_output, proj_finetune_output
@tf.keras.utils.register_keras_serializable(package='simclr')
class ClassificationHead(tf.keras.layers.Layer): class ClassificationHead(tf.keras.layers.Layer):
"""Classification Head.""" """Classification Head."""
......
...@@ -57,7 +57,6 @@ def cross_replica_concat(tensor: tf.Tensor, num_replicas: int) -> tf.Tensor: ...@@ -57,7 +57,6 @@ def cross_replica_concat(tensor: tf.Tensor, num_replicas: int) -> tf.Tensor:
return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:]) return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
@tf.keras.utils.register_keras_serializable(package='simclr')
class ContrastiveLoss(object): class ContrastiveLoss(object):
"""Contrastive training loss function.""" """Contrastive training loss function."""
......
...@@ -22,7 +22,6 @@ from official.modeling import tf_utils ...@@ -22,7 +22,6 @@ from official.modeling import tf_utils
regularizers = tf.keras.regularizers regularizers = tf.keras.regularizers
@tf.keras.utils.register_keras_serializable(package='simclr')
class DenseBN(tf.keras.layers.Layer): class DenseBN(tf.keras.layers.Layer):
"""Modified Dense layer to help build simclr system. """Modified Dense layer to help build simclr system.
......
...@@ -27,7 +27,6 @@ PROJECTION_OUTPUT_KEY = 'projection_outputs' ...@@ -27,7 +27,6 @@ PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs' SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
@tf.keras.utils.register_keras_serializable(package='simclr')
class SimCLRModel(tf.keras.Model): class SimCLRModel(tf.keras.Model):
"""A classification model based on SimCLR framework.""" """A classification model based on SimCLR framework."""
......
...@@ -27,7 +27,6 @@ from official.vision.beta.projects.video_ssl.configs import video_ssl as video_s ...@@ -27,7 +27,6 @@ from official.vision.beta.projects.video_ssl.configs import video_ssl as video_s
layers = tf.keras.layers layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model): class VideoSSLModel(tf.keras.Model):
"""A video ssl model class builder.""" """A video ssl model class builder."""
......
...@@ -371,7 +371,6 @@ BACKBONES = { ...@@ -371,7 +371,6 @@ BACKBONES = {
} }
@tf.keras.utils.register_keras_serializable(package='yolo')
class Darknet(tf.keras.Model): class Darknet(tf.keras.Model):
"""The Darknet backbone architecture.""" """The Darknet backbone architecture."""
......
...@@ -84,14 +84,12 @@ YOLO_MODELS = { ...@@ -84,14 +84,12 @@ YOLO_MODELS = {
} }
@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer): class _IdentityRoute(tf.keras.layers.Layer):
def call(self, inputs): # pylint: disable=arguments-differ def call(self, inputs): # pylint: disable=arguments-differ
return None, inputs return None, inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloFPN(tf.keras.layers.Layer): class YoloFPN(tf.keras.layers.Layer):
"""YOLO Feature pyramid network.""" """YOLO Feature pyramid network."""
...@@ -248,7 +246,6 @@ class YoloFPN(tf.keras.layers.Layer): ...@@ -248,7 +246,6 @@ class YoloFPN(tf.keras.layers.Layer):
return outputs return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer): class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network.""" """YOLO Path Aggregation Network."""
...@@ -441,7 +438,6 @@ class YoloPAN(tf.keras.layers.Layer): ...@@ -441,7 +438,6 @@ class YoloPAN(tf.keras.layers.Layer):
return outputs return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloDecoder(tf.keras.Model): class YoloDecoder(tf.keras.Model):
"""Darknet Backbone Decoder.""" """Darknet Backbone Decoder."""
......
...@@ -21,7 +21,6 @@ from official.vision.beta.projects.yolo.ops import box_ops ...@@ -21,7 +21,6 @@ from official.vision.beta.projects.yolo.ops import box_ops
from official.vision.beta.projects.yolo.ops import loss_utils from official.vision.beta.projects.yolo.ops import loss_utils
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloLayer(tf.keras.Model): class YoloLayer(tf.keras.Model):
"""Yolo layer (detection generator).""" """Yolo layer (detection generator)."""
......
...@@ -21,14 +21,12 @@ from official.modeling import tf_utils ...@@ -21,14 +21,12 @@ from official.modeling import tf_utils
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
@tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer): class Identity(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
return inputs return inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class ConvBN(tf.keras.layers.Layer): class ConvBN(tf.keras.layers.Layer):
"""ConvBN block. """ConvBN block.
...@@ -241,7 +239,6 @@ class ConvBN(tf.keras.layers.Layer): ...@@ -241,7 +239,6 @@ class ConvBN(tf.keras.layers.Layer):
return layer_config return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkResidual(tf.keras.layers.Layer): class DarkResidual(tf.keras.layers.Layer):
"""Darknet block with Residual connection for Yolo v3 Backbone.""" """Darknet block with Residual connection for Yolo v3 Backbone."""
...@@ -406,7 +403,6 @@ class DarkResidual(tf.keras.layers.Layer): ...@@ -406,7 +403,6 @@ class DarkResidual(tf.keras.layers.Layer):
return layer_config return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPTiny(tf.keras.layers.Layer): class CSPTiny(tf.keras.layers.Layer):
"""CSP Tiny layer. """CSP Tiny layer.
...@@ -556,7 +552,6 @@ class CSPTiny(tf.keras.layers.Layer): ...@@ -556,7 +552,6 @@ class CSPTiny(tf.keras.layers.Layer):
return x, x5 return x, x5
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPRoute(tf.keras.layers.Layer): class CSPRoute(tf.keras.layers.Layer):
"""CSPRoute block. """CSPRoute block.
...@@ -696,7 +691,6 @@ class CSPRoute(tf.keras.layers.Layer): ...@@ -696,7 +691,6 @@ class CSPRoute(tf.keras.layers.Layer):
return (x, y) return (x, y)
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPConnect(tf.keras.layers.Layer): class CSPConnect(tf.keras.layers.Layer):
"""CSPConnect block. """CSPConnect block.
...@@ -941,7 +935,6 @@ class CSPStack(tf.keras.layers.Layer): ...@@ -941,7 +935,6 @@ class CSPStack(tf.keras.layers.Layer):
return x return x
@tf.keras.utils.register_keras_serializable(package='yolo')
class PathAggregationBlock(tf.keras.layers.Layer): class PathAggregationBlock(tf.keras.layers.Layer):
"""Path Aggregation block.""" """Path Aggregation block."""
...@@ -1132,7 +1125,6 @@ class PathAggregationBlock(tf.keras.layers.Layer): ...@@ -1132,7 +1125,6 @@ class PathAggregationBlock(tf.keras.layers.Layer):
return self._call_regular(inputs, training=training) return self._call_regular(inputs, training=training)
@tf.keras.utils.register_keras_serializable(package='yolo')
class SPP(tf.keras.layers.Layer): class SPP(tf.keras.layers.Layer):
"""Spatial Pyramid Pooling. """Spatial Pyramid Pooling.
...@@ -1411,7 +1403,6 @@ class CBAM(tf.keras.layers.Layer): ...@@ -1411,7 +1403,6 @@ class CBAM(tf.keras.layers.Layer):
return self._sam(self._cam(inputs)) return self._sam(self._cam(inputs))
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkRouteProcess(tf.keras.layers.Layer): class DarkRouteProcess(tf.keras.layers.Layer):
"""Dark Route Process block. """Dark Route Process block.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment