"awq/vscode:/vscode.git/clone" did not exist on "44b3187abc3c6836a5561bf72ddca9a62bd81eec"
Commit aa3e639f authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Standardize fine tune checkpoints across all TF2 models.

PiperOrigin-RevId: 372426423
parent f006521b
...@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model): ...@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass pass
@property @property
@abc.abstractmethod def classification_backbone(self):
def supported_sub_model_types(self): raise NotImplementedError(
"""Valid sub model types supported by the get_sub_model function.""" 'Classification backbone not supported for {}'.format(type(self)))
pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256), def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
...@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint). A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
""" """
supported_types = self._feature_extractor.supported_sub_model_types if fine_tune_checkpoint_type == 'detection':
supported_types += ['fine_tune']
if fine_tune_checkpoint_type not in supported_types:
message = ('Checkpoint type "{}" not supported for {}. '
'Supported types are {}')
raise ValueError(
message.format(fine_tune_checkpoint_type,
self._feature_extractor.__class__.__name__,
supported_types))
elif fine_tune_checkpoint_type == 'fine_tune':
feature_extractor_model = tf.train.Checkpoint( feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor) _feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model} return {'model': feature_extractor_model}
elif fine_tune_checkpoint_type == 'classification':
return {
'feature_extractor':
self._feature_extractor.classification_backbone
}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
elif fine_tune_checkpoint_type == 'fine_tune':
raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
'Please set fine_tune_checkpoint_type to "detection"'
' which has the same functionality. If you are using'
' the ExtremeNet checkpoint, download the new version'
' from the model zoo.'))
else: else:
return {'feature_extractor': self._feature_extractor.get_sub_model( raise ValueError('Unknown fine tune checkpoint type {}'.format(
fine_tune_checkpoint_type)} fine_tune_checkpoint_type))
def updates(self): def updates(self):
if tf_version.is_tf2(): if tf_version.is_tf2():
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
from __future__ import division from __future__ import division
import functools import functools
import re
import unittest import unittest
from absl.testing import parameterized from absl.testing import parameterized
...@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase): ...@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self.assertIsInstance(restore_from_objects_map['feature_extractor'], self.assertIsInstance(restore_from_objects_map['feature_extractor'],
tf.keras.Model) tf.keras.Model)
def test_retore_map_error(self): def test_retore_map_detection(self):
"""Test that restoring unsupported checkpoint type raises an error.""" """Test that detection checkpoints can be restored."""
model = build_center_net_meta_arch(build_resnet=True) model = build_center_net_meta_arch(build_resnet=True)
msg = ("Checkpoint type \"detection\" not supported for " restore_from_objects_map = model.restore_from_objects('detection')
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']") self.assertIsInstance(restore_from_objects_map['model']._feature_extractor,
with self.assertRaisesRegex(ValueError, re.escape(msg)): tf.keras.Model)
model.restore_from_objects('detection')
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor): class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
...@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
_feature_extractor_for_proposal_features= _feature_extractor_for_proposal_features=
self._feature_extractor_for_proposal_features) self._feature_extractor_for_proposal_features)
return {'model': fake_model} return {'model': fake_model}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
else: else:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format( raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type)) fine_tune_checkpoint_type))
......
...@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2 ...@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from object_detection.utils import config_util from object_detection.utils import config_util
from object_detection.utils import label_map_util from object_detection.utils import label_map_util
from object_detection.utils import ops from object_detection.utils import ops
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vutils from object_detection.utils import visualization_utils as vutils
...@@ -587,6 +588,9 @@ def train_loop( ...@@ -587,6 +588,9 @@ def train_loop(
lambda: global_step % num_steps_per_iteration == 0): lambda: global_step % num_steps_per_iteration == 0):
# Load a fine-tuning checkpoint. # Load a fine-tuning checkpoint.
if train_config.fine_tune_checkpoint: if train_config.fine_tune_checkpoint:
variables_helper.ensure_checkpoint_supported(
train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
model_dir)
load_fine_tune_checkpoint( load_fine_tune_checkpoint(
detection_model, train_config.fine_tune_checkpoint, detection_model, train_config.fine_tune_checkpoint,
fine_tune_checkpoint_type, fine_tune_checkpoint_version, fine_tune_checkpoint_type, fine_tune_checkpoint_version,
......
...@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor( ...@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses return self._network.num_hourglasses
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs): def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet.""" """The Hourglass-10 backbone for CenterNet."""
......
...@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._network.load_weights(path) self._network.load_weights(path)
def get_base_model(self):
return self._network
def call(self, inputs): def call(self, inputs):
return [self._network(inputs)] return [self._network(inputs)]
...@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
return 1 return 1
@property @property
def supported_sub_model_types(self): def classification_backbone(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network return self._network
else:
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering, def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
......
...@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
return 4 return 4
@property @property
def supported_sub_model_types(self): def classification_backbone(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model return self._base_model
else:
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs): def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
......
...@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
return 4 return 4
@property @property
def supported_sub_model_types(self): def classification_backbone(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model return self._base_model
else:
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
......
...@@ -40,12 +40,26 @@ message TrainConfig { ...@@ -40,12 +40,26 @@ message TrainConfig {
// extractor variables trained outside of object detection. // extractor variables trained outside of object detection.
optional string fine_tune_checkpoint = 7 [default=""]; optional string fine_tune_checkpoint = 7 [default=""];
// Type of checkpoint to restore variables from, e.g. 'classification' // This option controls how variables are restored from the (pre-trained)
// 'detection', `fine_tune`, `full`. Controls which variables are restored // fine_tune_checkpoint. For TF2 models, 3 different types are supported:
// from the pre-trained checkpoint. For meta architecture specific valid // 1. "classification": Restores only the classification backbone part of
// values of this parameter, see the restore_map (TF1) or // the feature extractor. This option is typically used when you want
// to train a detection model starting from a pre-trained image
// classification model, e.g. a ResNet model pre-trained on ImageNet.
// 2. "detection": Restores the entire feature extractor. The only parts
// of the full detection model that are not restored are the box and
// class prediction heads. This option is typically used when you want
// to use a pre-trained detection model and train on a new dataset or
// task which requires different box and class prediction heads.
// 3. "full": Restores the entire detection model, including the
// feature extractor, its classification backbone, and the prediction
// heads. This option should only be used when the pre-training and
// fine-tuning tasks are the same. Otherwise, the model's parameters
// may have incompatible shapes, which will cause errors when
// attempting to restore the checkpoint.
// For more details about this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the // restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files // /meta_architectures/*meta_arch.py files.
optional string fine_tune_checkpoint_type = 22 [default=""]; optional string fine_tune_checkpoint_type = 22 [default=""];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow // Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
......
...@@ -21,6 +21,7 @@ from __future__ import division ...@@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import logging import logging
import os
import re import re
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -29,6 +30,19 @@ import tf_slim as slim ...@@ -29,6 +30,19 @@ import tf_slim as slim
from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variables as tf_variables
# Maps checkpoint types to variable name prefixes that are no longer
# supported
# Error message surfaced (via RuntimeError) when a 'detection' checkpoint
# still contains variables under the deprecated 'feature_extractor' prefix.
# NOTE: this is a runtime message, not a docstring — do not edit casually.
DETECTION_FEATURE_EXTRACTOR_MSG = """\
The checkpoint type 'detection' is not supported when it contains variable
names with 'feature_extractor'. Please download the new checkpoint file
from model zoo.
"""
# checkpoint_type -> (blocked variable-name prefix, message to raise when a
# checkpoint variable starts with that prefix).
DEPRECATED_CHECKPOINT_MAP = {
'detection': ('feature_extractor', DETECTION_FEATURE_EXTRACTOR_MSG)
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in # TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py # tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False): def filter_variables(variables, filter_regex_list, invert=False):
...@@ -176,3 +190,41 @@ def get_global_variables_safely(): ...@@ -176,3 +190,41 @@ def get_global_variables_safely():
"executing eagerly. Use a Keras model's `.variables` " "executing eagerly. Use a Keras model's `.variables` "
"attribute instead.") "attribute instead.")
return tf.global_variables() return tf.global_variables()
def ensure_checkpoint_supported(checkpoint_path, checkpoint_type, model_dir):
  """Ensures that the given checkpoint can be properly loaded.

  Performs the following checks
  1. Raises an error if checkpoint_path and model_dir are same.
  2. Checks that checkpoint_path does not contain a deprecated checkpoint file
     by inspecting its variables.

  Args:
    checkpoint_path: str, path to checkpoint.
    checkpoint_type: str, denotes the type of checkpoint.
    model_dir: The model directory to store intermediate training checkpoints.

  Raises:
    RuntimeError: If
      1. We detect an deprecated checkpoint file.
      2. model_dir and checkpoint_path are in the same directory.
  """
  if checkpoint_type in DEPRECATED_CHECKPOINT_MAP:
    # Only list the checkpoint's variables when this checkpoint type needs
    # screening: tf.train.list_variables reads checkpoint index files from
    # disk, which is wasted work for non-deprecated types.
    blocked_prefix, msg = DEPRECATED_CHECKPOINT_MAP[checkpoint_type]
    for var_name, _ in tf.train.list_variables(checkpoint_path):
      if var_name.startswith(blocked_prefix):
        tf.logging.error('Found variable name - %s with prefix %s', var_name,
                         blocked_prefix)
        raise RuntimeError(msg)

  # Compare canonicalized directories so that relative vs. absolute spellings
  # of the same path are still detected as a conflict.
  checkpoint_path_dir = os.path.abspath(os.path.dirname(checkpoint_path))
  model_dir = os.path.abspath(model_dir)

  if model_dir == checkpoint_path_dir:
    raise RuntimeError(
        ('Checkpoint dir ({}) and model_dir ({}) cannot be same.'.format(
            checkpoint_path_dir, model_dir) +
         (' Please set model_dir to a different path.')))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment