Commit fd6987fa authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Support fine_tune in all CenterNet feature extractors.

PiperOrigin-RevId: 326528933
parent f41f14e6
...@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model): ...@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
pass pass
@property
@abc.abstractmethod
def supported_sub_model_types(self):
"""Valid sub model types supported by the get_sub_model function."""
pass
@abc.abstractmethod @abc.abstractmethod
def get_sub_model(self, sub_model_type): def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type. """Returns the underlying keras model for the given sub_model_type.
...@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel):
fine_tune_checkpoint_type: whether to restore from a full detection fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training. classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'. Valid values: `detection`, `classification`, `fine_tune`.
'detection': used when loading in the Hourglass model pre-trained on Default 'detection'.
other detection task. 'detection': used when loading models pre-trained on other detection
'classification': used when loading in the ResNet model pre-trained on tasks. With this checkpoint type the weights of the feature extractor
image classification task. Note that only the image feature encoding are expected under the attribute 'feature_extractor'.
part is loaded but not those upsampling layers. 'classification': used when loading models pre-trained on an image
classification task. Note that only the encoder section of the network
is loaded and not the upsampling layers. With this checkpoint type,
the weights of only the encoder section are expected under the
attribute 'feature_extractor'.
'fine_tune': used when loading the entire CenterNet feature extractor 'fine_tune': used when loading the entire CenterNet feature extractor
pre-trained on other tasks. The checkpoints saved during CenterNet pre-trained on other tasks. The checkpoints saved during CenterNet
model training can be directly loaded using this mode. model training can be directly loaded using this type. With this
checkpoint type, the weights of the feature extractor are expected
under the attribute 'model._feature_extractor'.
For more details, see the tensorflow section on Loading mechanics.
https://www.tensorflow.org/guide/checkpoint#loading_mechanics
Returns: Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint). A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
""" """
sub_model = self._feature_extractor.get_sub_model(fine_tune_checkpoint_type) supported_types = self._feature_extractor.supported_sub_model_types
return {'feature_extractor': sub_model} supported_types += ['fine_tune']
if fine_tune_checkpoint_type not in supported_types:
message = ('Checkpoint type "{}" not supported for {}. '
'Supported types are {}')
raise ValueError(
message.format(fine_tune_checkpoint_type,
self._feature_extractor.__class__.__name__,
supported_types))
elif fine_tune_checkpoint_type == 'fine_tune':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
else:
return {'feature_extractor': self._feature_extractor.get_sub_model(
fine_tune_checkpoint_type)}
def updates(self): def updates(self):
raise RuntimeError('This model is intended to be used with model_lib_v2 ' raise RuntimeError('This model is intended to be used with model_lib_v2 '
......
...@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase): ...@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
"""Test that restoring unsupported checkpoint type raises an error.""" """Test that restoring unsupported checkpoint type raises an error."""
model = build_center_net_meta_arch(build_resnet=True) model = build_center_net_meta_arch(build_resnet=True)
msg = ("Sub model detection is not defined for ResNet." msg = ("Checkpoint type \"detection\" not supported for "
"Supported types are ['classification'].") "CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']")
with self.assertRaisesRegex(ValueError, re.escape(msg)): with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection') model.restore_from_objects('detection')
......
...@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor( ...@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses return self._network.num_hourglasses
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type): def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection': if sub_model_type == 'detection':
return self._network return self._network
else: else:
supported_types = ['detection'] ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
raise ValueError(
('Sub model {} is not defined for Hourglass.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def hourglass_104(channel_means, channel_stds, bgr_ordering): def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
...@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor.""" """The number of feature outputs returned by the feature extractor."""
return 1 return 1
@property
def supported_sub_model_types(self):
return ['detection']
def get_sub_model(self, sub_model_type): def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection': if sub_model_type == 'detection':
return self._network return self._network
else: else:
supported_types = ['detection'] ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
raise ValueError(
('Sub model {} is not defined for MobileNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering): def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
...@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type): def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification': if sub_model_type == 'classification':
return self._base_model return self._base_model
else: else:
supported_types = ['classification'] ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
raise ValueError(
('Sub model {} is not defined for ResNet.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)
+ 'Use the script convert_keras_models.py to create your own '
+ 'classification checkpoints.'))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering): def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
......
...@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
@property
def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type): def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification': if sub_model_type == 'classification':
return self._base_model return self._base_model
else: else:
supported_types = ['classification'] ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
raise ValueError(
('Sub model {} is not defined for ResNet FPN.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types))
+ 'Use the script convert_keras_models.py to create your own '
+ 'classification checkpoints.')
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment