"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "a25e8e42eb657c6a34ae67fc1fe69a19a5bef1be"
Commit 07484704 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Better error message when loading a wrong checkpoint type in CenterNet.

PiperOrigin-RevId: 322967458
parent bdaa525b
...@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model): ...@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
pass pass
@abc.abstractmethod
def get_sub_model(self, sub_model_type):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256, def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
bias_fill=None): bias_fill=None):
...@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint). A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
""" """
if fine_tune_checkpoint_type == 'classification': sub_model = self._feature_extractor.get_sub_model(fine_tune_checkpoint_type)
return {'feature_extractor': self._feature_extractor.get_base_model()} return {'feature_extractor': sub_model}
elif fine_tune_checkpoint_type == 'detection':
return {'feature_extractor': self._feature_extractor.get_model()}
elif fine_tune_checkpoint_type == 'fine_tune':
feature_extractor_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': feature_extractor_model}
else:
raise ValueError('Not supported fine tune checkpoint type - {}'.format(
fine_tune_checkpoint_type))
def updates(self): def updates(self):
raise RuntimeError('This model is intended to be used with model_lib_v2 ' raise RuntimeError('This model is intended to be used with model_lib_v2 '
......
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
from __future__ import division from __future__ import division
import functools import functools
import re
import unittest import unittest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase): ...@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self.assertIsInstance(restore_from_objects_map['feature_extractor'], self.assertIsInstance(restore_from_objects_map['feature_extractor'],
tf.keras.Model) tf.keras.Model)
def test_retore_map_error(self):
"""Test that restoring unsupported checkpoint type raises an error."""
model = build_center_net_meta_arch(build_resnet=True)
msg = ("Sub model detection is not defined for ResNet."
"Supported types are ['classification'].")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
model.restore_from_objects('detection')
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor): class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
......
...@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor( ...@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses return self._network.num_hourglasses
def get_model(self): def get_sub_model(self, sub_model_type):
return self._network if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for Hourglass.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def hourglass_104(channel_means, channel_stds, bgr_ordering): def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
...@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor.""" """The number of feature outputs returned by the feature extractor."""
return 1 return 1
def get_model(self): def get_sub_model(self, sub_model_type):
return self._network if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for MobileNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering): def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
...@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path) self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs): def call(self, inputs):
"""Returns image features extracted by the backbone. """Returns image features extracted by the backbone.
...@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering): def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
"""The ResNet v2 101 feature extractor.""" """The ResNet v2 101 feature extractor."""
......
...@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path) self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs): def call(self, inputs):
"""Returns image features extracted by the backbone. """Returns image features extracted by the backbone.
...@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet FPN.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 101 FPN feature extractor.""" """The ResNet v1 101 FPN feature extractor."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment