Commit 07484704 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Better error message when loading a wrong checkpoint type in CenterNet.

PiperOrigin-RevId: 322967458
parent bdaa525b
......@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
bias_fill=None):
......@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
if fine_tune_checkpoint_type == 'classification':
return {'feature_extractor': self._feature_extractor.get_base_model()}
elif fine_tune_checkpoint_type == 'detection':
return {'feature_extractor': self._feature_extractor.get_model()}
elif fine_tune_checkpoint_type == 'fine_tune':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
else:
raise ValueError('Not supported fine tune checkpoint type - {}'.format(
fine_tune_checkpoint_type))
sub_model = self._feature_extractor.get_sub_model(fine_tune_checkpoint_type)
return {'feature_extractor': sub_model}
def updates(self):
raise RuntimeError('This model is intended to be used with model_lib_v2 '
......
......@@ -17,7 +17,9 @@
from __future__ import division
import functools
import re
import unittest
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
......@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self.assertIsInstance(restore_from_objects_map['feature_extractor'],
tf.keras.Model)
def test_retore_map_error(self):
"""Test that restoring unsupported checkpoint type raises an error."""
model = build_center_net_meta_arch(build_resnet=True)
msg = ("Sub model detection is not defined for ResNet."
"Supported types are ['classification'].")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection')
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
......@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses
def get_model(self):
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for Hourglass.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
......@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
return 1
def get_model(self):
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for MobileNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
......@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs):
"""Returns image features extracted by the backbone.
......@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
"""The ResNet v2 101 feature extractor."""
......
......@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs):
"""Returns image features extracted by the backbone.
......@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet FPN.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 101 FPN feature extractor."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment