Commit aa3e639f authored by Vighnesh Birodkar, committed by TF Object Detection Team

Standardize fine tune checkpoints across all TF2 models.

PiperOrigin-RevId: 372426423
parent f006521b
......@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
@property
@abc.abstractmethod
def supported_sub_model_types(self):
"""Valid sub model types supported by the get_sub_model function."""
pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def classification_backbone(self):
raise NotImplementedError(
'Classification backbone not supported for {}'.format(type(self)))
def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
......@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
supported_types = self._feature_extractor.supported_sub_model_types
supported_types += ['fine_tune']
if fine_tune_checkpoint_type not in supported_types:
message = ('Checkpoint type "{}" not supported for {}. '
'Supported types are {}')
raise ValueError(
message.format(fine_tune_checkpoint_type,
self._feature_extractor.__class__.__name__,
supported_types))
elif fine_tune_checkpoint_type == 'fine_tune':
if fine_tune_checkpoint_type == 'detection':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
elif fine_tune_checkpoint_type == 'classification':
return {
'feature_extractor':
self._feature_extractor.classification_backbone
}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
elif fine_tune_checkpoint_type == 'fine_tune':
raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
'Please set fine_tune_checkpoint_type to "detection"'
' which has the same functionality. If you are using'
' the ExtremeNet checkpoint, download the new version'
' from the model zoo.'))
else:
return {'feature_extractor': self._feature_extractor.get_sub_model(
fine_tune_checkpoint_type)}
raise ValueError('Unknown fine tune checkpoint type {}'.format(
fine_tune_checkpoint_type))
def updates(self):
if tf_version.is_tf2():
......
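For context, callers consume the map returned by restore_from_objects by wrapping it in a tf.train.Checkpoint and restoring from the fine-tune checkpoint path. A minimal, hedged sketch (the model and checkpoint path are placeholders, not part of this change):

import tensorflow as tf

def restore_fine_tune_checkpoint(detection_model, checkpoint_path,
                                 checkpoint_type='detection'):
  """Restores a fine-tune checkpoint via restore_from_objects.

  Args:
    detection_model: an already-built DetectionModel (e.g. CenterNetMetaArch).
    checkpoint_path: path to a TF2 object-based checkpoint.
    checkpoint_type: one of 'classification', 'detection' or 'full'.
  """
  restore_map = detection_model.restore_from_objects(
      fine_tune_checkpoint_type=checkpoint_type)
  # Wrap the returned Trackable objects in a Checkpoint; variables that are
  # absent from the checkpoint (e.g. freshly initialized prediction heads)
  # keep their initial values.
  ckpt = tf.train.Checkpoint(**restore_map)
  ckpt.restore(checkpoint_path).expect_partial()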
......@@ -17,7 +17,6 @@
from __future__ import division
import functools
import re
import unittest
from absl.testing import parameterized
......@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self.assertIsInstance(restore_from_objects_map['feature_extractor'],
tf.keras.Model)
def test_restore_map_error(self):
"""Test that restoring unsupported checkpoint type raises an error."""
def test_restore_map_detection(self):
"""Test that detection checkpoints can be restored."""
model = build_center_net_meta_arch(build_resnet=True)
msg = ("Checkpoint type \"detection\" not supported for "
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection')
restore_from_objects_map = model.restore_from_objects('detection')
self.assertIsInstance(restore_from_objects_map['model']._feature_extractor,
tf.keras.Model)
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
......@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
_feature_extractor_for_proposal_features=
self._feature_extractor_for_proposal_features)
return {'model': fake_model}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
else:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type))
......
......@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vutils
......@@ -587,6 +588,9 @@ def train_loop(
lambda: global_step % num_steps_per_iteration == 0):
# Load a fine-tuning checkpoint.
if train_config.fine_tune_checkpoint:
variables_helper.ensure_checkpoint_supported(
train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
model_dir)
load_fine_tune_checkpoint(
detection_model, train_config.fine_tune_checkpoint,
fine_tune_checkpoint_type, fine_tune_checkpoint_version,
......
......@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet."""
......
......@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
def load_feature_extractor_weights(self, path):
self._network.load_weights(path)
def get_base_model(self):
return self._network
def call(self, inputs):
return [self._network(inputs)]
......@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
return 1
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def classification_backbone(self):
return self._network
def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
......
......@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def classification_backbone(self):
return self._base_model
def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
......
......@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
raise ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def classification_backbone(self):
return self._base_model
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
......
......@@ -40,12 +40,26 @@ message TrainConfig {
// extractor variables trained outside of object detection.
optional string fine_tune_checkpoint = 7 [default=""];
// Type of checkpoint to restore variables from, e.g. 'classification'
// 'detection', `fine_tune`, `full`. Controls which variables are restored
// from the pre-trained checkpoint. For meta architecture specific valid
// values of this parameter, see the restore_map (TF1) or
// This option controls how variables are restored from the (pre-trained)
// fine_tune_checkpoint. For TF2 models, 3 different types are supported:
// 1. "classification": Restores only the classification backbone part of
// the feature extractor. This option is typically used when you want
// to train a detection model starting from a pre-trained image
// classification model, e.g. a ResNet model pre-trained on ImageNet.
// 2. "detection": Restores the entire feature extractor. The only parts
// of the full detection model that are not restored are the box and
// class prediction heads. This option is typically used when you want
// to use a pre-trained detection model and train on a new dataset or
// task which requires different box and class prediction heads.
// 3. "full": Restores the entire detection model, including the
// feature extractor, its classification backbone, and the prediction
// heads. This option should only be used when the pre-training and
// fine-tuning tasks are the same. Otherwise, the model's parameters
// may have incompatible shapes, which will cause errors when
// attempting to restore the checkpoint.
// For more details about this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files
// /meta_architectures/*meta_arch.py files.
optional string fine_tune_checkpoint_type = 22 [default=""];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
......
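To make the supported modes concrete, here is a hedged sketch of setting these fields programmatically with the train_pb2 proto (the checkpoint path below is illustrative, not part of this change):

from google.protobuf import text_format
from object_detection.protos import train_pb2

# 'detection' restores the full feature extractor from a pre-trained
# detection checkpoint; only the box/class prediction heads start fresh.
train_config = text_format.Parse(
    """
    fine_tune_checkpoint: "/path/to/pretrained/ckpt-0"
    fine_tune_checkpoint_type: "detection"
    """,
    train_pb2.TrainConfig())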
......@@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import logging
import os
import re
import tensorflow.compat.v1 as tf
......@@ -29,6 +30,19 @@ import tf_slim as slim
from tensorflow.python.ops import variables as tf_variables
# Maps checkpoint types to variable name prefixes that are no longer
# supported
DETECTION_FEATURE_EXTRACTOR_MSG = """\
The checkpoint type 'detection' is not supported when the checkpoint contains
variable names prefixed with 'feature_extractor'. Please download the new
checkpoint file from the model zoo.
"""
DEPRECATED_CHECKPOINT_MAP = {
'detection': ('feature_extractor', DETECTION_FEATURE_EXTRACTOR_MSG)
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
......@@ -176,3 +190,41 @@ def get_global_variables_safely():
"executing eagerly. Use a Keras model's `.variables` "
"attribute instead.")
return tf.global_variables()
def ensure_checkpoint_supported(checkpoint_path, checkpoint_type, model_dir):
"""Ensures that the given checkpoint can be properly loaded.
Performs the following checks:
1. Raises an error if checkpoint_path and model_dir are the same.
2. Checks that checkpoint_path does not contain a deprecated checkpoint file
by inspecting its variables.
Args:
checkpoint_path: str, path to checkpoint.
checkpoint_type: str, denotes the type of checkpoint.
model_dir: The model directory to store intermediate training checkpoints.
Raises:
RuntimeError: If
1. We detect a deprecated checkpoint file.
2. model_dir and checkpoint_path are in the same directory.
"""
variables = tf.train.list_variables(checkpoint_path)
if checkpoint_type in DEPRECATED_CHECKPOINT_MAP:
blocked_prefix, msg = DEPRECATED_CHECKPOINT_MAP[checkpoint_type]
for var_name, _ in variables:
if var_name.startswith(blocked_prefix):
tf.logging.error('Found variable name - %s with prefix %s', var_name,
blocked_prefix)
raise RuntimeError(msg)
checkpoint_path_dir = os.path.abspath(os.path.dirname(checkpoint_path))
model_dir = os.path.abspath(model_dir)
if model_dir == checkpoint_path_dir:
raise RuntimeError(
('Checkpoint dir ({}) and model_dir ({}) cannot be the same.'.format(
checkpoint_path_dir, model_dir) +
(' Please set model_dir to a different path.')))
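A hedged usage sketch mirroring the train_loop change above (paths are illustrative; the check needs a real checkpoint on disk to inspect):

from object_detection.utils import variables_helper

fine_tune_checkpoint = '/tmp/pretrained_centernet/ckpt-0'
model_dir = '/tmp/my_training_run'

# Raises RuntimeError if the 'detection' checkpoint still uses the deprecated
# 'feature_extractor/...' variable naming, or if model_dir points at the
# directory that holds the checkpoint being restored.
variables_helper.ensure_checkpoint_supported(
    fine_tune_checkpoint, 'detection', model_dir)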