Commit c32ce7cf authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 410640806
parent 51cb03b0
......@@ -884,7 +884,6 @@ class AssembleNet(tf.keras.Model):
inputs=original_inputs, outputs=streams, **kwargs)
@tf.keras.utils.register_keras_serializable(package='Vision')
class AssembleNetModel(tf.keras.Model):
"""An AssembleNet model builder."""
......
......@@ -48,7 +48,6 @@ HOURGLASS_SPECS = {
}
@tf.keras.utils.register_keras_serializable(package='centernet')
class Hourglass(tf.keras.Model):
"""CenterNet Hourglass backbone."""
......
......@@ -21,7 +21,6 @@ import tensorflow as tf
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHead(tf.keras.Model):
"""CenterNet Head."""
......
......@@ -123,7 +123,6 @@ def _make_repeated_residual_blocks(
return tf.keras.Sequential(blocks)
@tf.keras.utils.register_keras_serializable(package='centernet')
class HourglassBlock(tf.keras.layers.Layer):
"""Hourglass module: an encoder-decoder block."""
......@@ -274,7 +273,6 @@ class HourglassBlock(tf.keras.layers.Layer):
return config
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHeadConv(tf.keras.layers.Layer):
"""Convolution block for the CenterNet head."""
......
......@@ -30,7 +30,6 @@ from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import nms_ops
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetDetectionGenerator(tf.keras.layers.Layer):
"""CenterNet Detection Generator."""
......
......@@ -23,7 +23,6 @@ from official.modeling import tf_utils
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import hourglass_network
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskHead(tf.keras.layers.Layer):
"""Creates a mask head."""
......
......@@ -31,7 +31,6 @@ def resize_as(source, size):
return tf.transpose(source, (0, 3, 1, 2))
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Mask R-CNN model."""
......
......@@ -25,7 +25,6 @@ import tensorflow as tf
from official.vision.beta.projects.example import example_config as example_cfg
@tf.keras.utils.register_keras_serializable(package='Vision')
class ExampleModel(tf.keras.Model):
"""A example model class.
......
......@@ -21,7 +21,6 @@ import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model."""
......
......@@ -24,7 +24,6 @@ regularizers = tf.keras.regularizers
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='simclr')
class ProjectionHead(tf.keras.layers.Layer):
"""Projection head."""
......@@ -144,7 +143,6 @@ class ProjectionHead(tf.keras.layers.Layer):
return proj_head_output, proj_finetune_output
@tf.keras.utils.register_keras_serializable(package='simclr')
class ClassificationHead(tf.keras.layers.Layer):
"""Classification Head."""
......
......@@ -57,7 +57,6 @@ def cross_replica_concat(tensor: tf.Tensor, num_replicas: int) -> tf.Tensor:
return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
@tf.keras.utils.register_keras_serializable(package='simclr')
class ContrastiveLoss(object):
"""Contrastive training loss function."""
......
......@@ -22,7 +22,6 @@ from official.modeling import tf_utils
regularizers = tf.keras.regularizers
@tf.keras.utils.register_keras_serializable(package='simclr')
class DenseBN(tf.keras.layers.Layer):
"""Modified Dense layer to help build simclr system.
......
......@@ -27,7 +27,6 @@ PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
@tf.keras.utils.register_keras_serializable(package='simclr')
class SimCLRModel(tf.keras.Model):
"""A classification model based on SimCLR framework."""
......
......@@ -27,7 +27,6 @@ from official.vision.beta.projects.video_ssl.configs import video_ssl as video_s
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model):
"""A video ssl model class builder."""
......
......@@ -371,7 +371,6 @@ BACKBONES = {
}
@tf.keras.utils.register_keras_serializable(package='yolo')
class Darknet(tf.keras.Model):
"""The Darknet backbone architecture."""
......
......@@ -84,14 +84,12 @@ YOLO_MODELS = {
}
@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer):
def call(self, inputs): # pylint: disable=arguments-differ
return None, inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloFPN(tf.keras.layers.Layer):
"""YOLO Feature pyramid network."""
......@@ -248,7 +246,6 @@ class YoloFPN(tf.keras.layers.Layer):
return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network."""
......@@ -441,7 +438,6 @@ class YoloPAN(tf.keras.layers.Layer):
return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloDecoder(tf.keras.Model):
"""Darknet Backbone Decoder."""
......
......@@ -21,7 +21,6 @@ from official.vision.beta.projects.yolo.ops import box_ops
from official.vision.beta.projects.yolo.ops import loss_utils
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloLayer(tf.keras.Model):
"""Yolo layer (detection generator)."""
......
......@@ -21,14 +21,12 @@ from official.modeling import tf_utils
from official.vision.beta.ops import spatial_transform_ops
@tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer):
def call(self, inputs):
return inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class ConvBN(tf.keras.layers.Layer):
"""ConvBN block.
......@@ -241,7 +239,6 @@ class ConvBN(tf.keras.layers.Layer):
return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkResidual(tf.keras.layers.Layer):
"""Darknet block with Residual connection for Yolo v3 Backbone."""
......@@ -406,7 +403,6 @@ class DarkResidual(tf.keras.layers.Layer):
return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPTiny(tf.keras.layers.Layer):
"""CSP Tiny layer.
......@@ -556,7 +552,6 @@ class CSPTiny(tf.keras.layers.Layer):
return x, x5
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPRoute(tf.keras.layers.Layer):
"""CSPRoute block.
......@@ -696,7 +691,6 @@ class CSPRoute(tf.keras.layers.Layer):
return (x, y)
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPConnect(tf.keras.layers.Layer):
"""CSPConnect block.
......@@ -941,7 +935,6 @@ class CSPStack(tf.keras.layers.Layer):
return x
@tf.keras.utils.register_keras_serializable(package='yolo')
class PathAggregationBlock(tf.keras.layers.Layer):
"""Path Aggregation block."""
......@@ -1132,7 +1125,6 @@ class PathAggregationBlock(tf.keras.layers.Layer):
return self._call_regular(inputs, training=training)
@tf.keras.utils.register_keras_serializable(package='yolo')
class SPP(tf.keras.layers.Layer):
"""Spatial Pyramid Pooling.
......@@ -1411,7 +1403,6 @@ class CBAM(tf.keras.layers.Layer):
return self._sam(self._cam(inputs))
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkRouteProcess(tf.keras.layers.Layer):
"""Dark Route Process block.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment