"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "2c63b5cd2eeaf66c3a45e7c65da41d16fb8838ca"
Commit fd6987fa authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Support fine_tune in all CenterNet feature extractors.

PiperOrigin-RevId: 326528933
parent f41f14e6
......@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
pass
@property
@abc.abstractmethod
def supported_sub_model_types(self):
"""Valid sub model types supported by the get_sub_model function."""
pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
......@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel):
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
'detection': used when loading in the Hourglass model pre-trained on
other detection task.
'classification': used when loading in the ResNet model pre-trained on
image classification task. Note that only the image feature encoding
part is loaded but not those upsampling layers.
Valid values: `detection`, `classification`, `fine_tune`.
Default 'detection'.
'detection': used when loading models pre-trained on other detection
tasks. With this checkpoint type the weights of the feature extractor
are expected under the attribute 'feature_extractor'.
'classification': used when loading models pre-trained on an image
classification task. Note that only the encoder section of the network
is loaded and not the upsampling layers. With this checkpoint type,
the weights of only the encoder section are expected under the
attribute 'feature_extractor'.
'fine_tune': used when loading the entire CenterNet feature extractor
pre-trained on other tasks. The checkpoints saved during CenterNet
model training can be directly loaded using this mode.
model training can be directly loaded using this type. With this
checkpoint type, the weights of the feature extractor are expected
under the attribute 'model._feature_extractor'.
For more details, see the tensorflow section on Loading mechanics.
https://www.tensorflow.org/guide/checkpoint#loading_mechanics
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
sub_model = self._feature_extractor.get_sub_model(fine_tune_checkpoint_type)
return {'feature_extractor': sub_model}
supported_types = self._feature_extractor.supported_sub_model_types
supported_types += ['fine_tune']
if fine_tune_checkpoint_type not in supported_types:
message = ('Checkpoint type "{}" not supported for {}. '
'Supported types are {}')
raise ValueError(
message.format(fine_tune_checkpoint_type,
self._feature_extractor.__class__.__name__,
supported_types))
elif fine_tune_checkpoint_type == 'fine_tune':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
else:
return {'feature_extractor': self._feature_extractor.get_sub_model(
fine_tune_checkpoint_type)}
def updates(self):
raise RuntimeError('This model is intended to be used with model_lib_v2 '
......
......@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
"""Test that restoring unsupported checkpoint type raises an error."""
model = build_center_net_meta_arch(build_resnet=True)
msg = ("Sub model detection is not defined for ResNet."
"Supported types are ['classification'].")
msg = ("Checkpoint type \"detection\" not supported for "
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection')
......
......@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for Hourglass.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
......@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
return 1
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for MobileNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
......@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)
+ 'Use the script convert_keras_models.py to create your own '
+ 'classification checkpoints.'))
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
......
......@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet FPN.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types))
+ 'Use the script convert_keras_models.py to create your own '
+ 'classification checkpoints.')
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment