Commit d87fd224 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 362396049
parent d2a9e4c2
......@@ -59,28 +59,10 @@ class ClassificationModel(tf.keras.Model):
skip_logits_layer: `bool`, whether to skip the prediction layer.
**kwargs: keyword arguments to be passed.
"""
self._self_setattr_tracking = False
self._config_dict = {
'backbone': backbone,
'num_classes': num_classes,
'input_specs': input_specs,
'dropout_rate': dropout_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'add_head_batch_norm': add_head_batch_norm,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
}
self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._backbone = backbone
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
norm = tf.keras.layers.BatchNormalization
axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
inputs = tf.keras.Input(shape=input_specs.shape[1:])
......@@ -88,18 +70,37 @@ class ClassificationModel(tf.keras.Model):
x = endpoints[max(endpoints.keys())]
if add_head_batch_norm:
x = self._norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
x = norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
if not skip_logits_layer:
x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(
num_classes, kernel_initializer=kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
num_classes,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer)(
x)
super(ClassificationModel, self).__init__(
inputs=inputs, outputs=x, **kwargs)
self._config_dict = {
'backbone': backbone,
'num_classes': num_classes,
'input_specs': input_specs,
'dropout_rate': dropout_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'add_head_batch_norm': add_head_batch_norm,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
}
self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._backbone = backbone
self._norm = norm
@property
def checkpoint_items(self):
......
......@@ -74,6 +74,9 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module()
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
self._export_from_module(module, input_type, tmp_dir)
......@@ -96,6 +99,10 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
shape=[224, 224, 3], dtype=tf.float32)))
expected_output = module.model(processed_images, training=False)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment