Commit 319589aa authored by vedanshu's avatar vedanshu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models

parents 64f323b1 eaeea071
...@@ -103,6 +103,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -103,6 +103,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
# a depth multiplier of 2. # a depth multiplier of 2.
self.assertEqual(64, first_conv.filters) self.assertEqual(64, first_conv.filters)
def test_center_net_mobilenet_v2_fpn_feature_extractor_interpolation(self):
  """Checks that FPN upsampling layers honor `upsampling_interpolation`."""
  means = (0., 0., 0.)
  stds = (1., 1., 1.)
  model = (
      center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
          means, stds, False, use_separable_conv=True,
          upsampling_interpolation='bilinear'))

  def graph_fn():
    image = np.zeros((8, 224, 224, 3), dtype=np.float32)
    return model(model.preprocess(image))

  predictions = self.execute(graph_fn, [])
  # 224 / 4 = 56: the extractor has an output stride of 4, 24 channels out.
  self.assertEqual((8, 56, 56, 24), predictions.shape)

  # Every upsampling op inside the FPN sub-model must use 'bilinear'
  # interpolation instead of the default 'nearest'.
  fpn_submodel = model.get_layer('model_1')
  for keras_layer in fpn_submodel.layers:
    if 'up_sampling2d' in keras_layer.name:
      self.assertEqual('bilinear', keras_layer.interpolation)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
return 4 return 4
@property @property
def supported_sub_model_types(self): def classification_backbone(self):
return ['classification'] return self._base_model
def get_sub_model(self, sub_model_type):
  """Returns the sub-model requested for checkpoint restoration.

  Args:
    sub_model_type: str, the requested sub-model type. Only 'classification'
      is supported.

  Returns:
    The classification backbone (`self._base_model`).

  Raises:
    ValueError: If `sub_model_type` is not supported.
  """
  if sub_model_type == 'classification':
    return self._base_model
  else:
    # Bug fix: the exception was previously constructed but never raised,
    # so unsupported types silently fell through and returned None.
    raise ValueError(
        'Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs): def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
......
...@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
return 4 return 4
@property @property
def supported_sub_model_types(self): def classification_backbone(self):
return ['classification'] return self._base_model
def get_sub_model(self, sub_model_type):
  """Returns the sub-model requested for checkpoint restoration.

  Args:
    sub_model_type: str, the requested sub-model type. Only 'classification'
      is supported.

  Returns:
    The classification backbone (`self._base_model`).

  Raises:
    ValueError: If `sub_model_type` is not supported.
  """
  if sub_model_type == 'classification':
    return self._base_model
  else:
    # Bug fix: the exception was previously constructed but never raised,
    # so unsupported types silently fell through and returned None.
    raise ValueError(
        'Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
......
...@@ -440,11 +440,17 @@ message CenterNetFeatureExtractor { ...@@ -440,11 +440,17 @@ message CenterNetFeatureExtractor {
optional bool use_depthwise = 5 [default = false]; optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`. // Depth multiplier. Only valid for specific models (e.g. MobileNet). See
// subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0]; optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific // Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`. // models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false]; optional bool use_separable_conv = 10 [default = false];
// Which interpolation method to use for the upsampling ops in the FPN.
// Currently only valid for CenterNetMobileNetV2FPNFeatureExtractor. The value
// can be one of 'nearest' or 'bilinear'.
optional string upsampling_interpolation = 11 [default = 'nearest'];
} }
...@@ -40,12 +40,26 @@ message TrainConfig { ...@@ -40,12 +40,26 @@ message TrainConfig {
// extractor variables trained outside of object detection. // extractor variables trained outside of object detection.
optional string fine_tune_checkpoint = 7 [default=""]; optional string fine_tune_checkpoint = 7 [default=""];
// Type of checkpoint to restore variables from, e.g. 'classification' // This option controls how variables are restored from the (pre-trained)
// 'detection', `fine_tune`, `full`. Controls which variables are restored // fine_tune_checkpoint. For TF2 models, 3 different types are supported:
// from the pre-trained checkpoint. For meta architecture specific valid // 1. "classification": Restores only the classification backbone part of
// values of this parameter, see the restore_map (TF1) or // the feature extractor. This option is typically used when you want
// to train a detection model starting from a pre-trained image
// classification model, e.g. a ResNet model pre-trained on ImageNet.
// 2. "detection": Restores the entire feature extractor. The only parts
// of the full detection model that are not restored are the box and
// class prediction heads. This option is typically used when you want
// to use a pre-trained detection model and train on a new dataset or
// task which requires different box and class prediction heads.
// 3. "full": Restores the entire detection model, including the
// feature extractor, its classification backbone, and the prediction
// heads. This option should only be used when the pre-training and
// fine-tuning tasks are the same. Otherwise, the model's parameters
// may have incompatible shapes, which will cause errors when
// attempting to restore the checkpoint.
// For more details about this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the // restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files // /meta_architectures/*meta_arch.py files.
optional string fine_tune_checkpoint_type = 22 [default=""]; optional string fine_tune_checkpoint_type = 22 [default=""];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow // Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
......
...@@ -21,6 +21,7 @@ from __future__ import division ...@@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import logging import logging
import os
import re import re
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -29,6 +30,19 @@ import tf_slim as slim ...@@ -29,6 +30,19 @@ import tf_slim as slim
from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variables as tf_variables
# Maps checkpoint types to variable name prefixes that are no longer
# supported
# Error message shown when a legacy 'detection' checkpoint is detected: such
# checkpoints stored variables under a 'feature_extractor' name prefix, which
# this codebase no longer restores.
DETECTION_FEATURE_EXTRACTOR_MSG = """\
The checkpoint type 'detection' is not supported when it contains variable
names with 'feature_extractor'. Please download the new checkpoint file
from model zoo.
"""
# Maps each deprecated checkpoint type to a (blocked variable-name prefix,
# error message) pair. Consulted by `ensure_checkpoint_supported` when
# validating a fine-tune checkpoint before training starts.
DEPRECATED_CHECKPOINT_MAP = {
    'detection': ('feature_extractor', DETECTION_FEATURE_EXTRACTOR_MSG)
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in # TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py # tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False): def filter_variables(variables, filter_regex_list, invert=False):
...@@ -176,3 +190,41 @@ def get_global_variables_safely(): ...@@ -176,3 +190,41 @@ def get_global_variables_safely():
"executing eagerly. Use a Keras model's `.variables` " "executing eagerly. Use a Keras model's `.variables` "
"attribute instead.") "attribute instead.")
return tf.global_variables() return tf.global_variables()
def ensure_checkpoint_supported(checkpoint_path, checkpoint_type, model_dir):
  """Ensures that the given checkpoint can be properly loaded.

  Performs the following checks:
  1. Raises an error if checkpoint_path and model_dir are same.
  2. Checks that checkpoint_path does not contain a deprecated checkpoint file
     by inspecting its variables.

  Args:
    checkpoint_path: str, path to checkpoint.
    checkpoint_type: str, denotes the type of checkpoint.
    model_dir: The model directory to store intermediate training checkpoints.

  Raises:
    RuntimeError: If
      1. We detect a deprecated checkpoint file.
      2. model_dir and checkpoint_path are in the same directory.
  """
  checkpoint_variables = tf.train.list_variables(checkpoint_path)

  # Reject checkpoints saved under variable-name layouts that are no longer
  # restorable (see DEPRECATED_CHECKPOINT_MAP).
  if checkpoint_type in DEPRECATED_CHECKPOINT_MAP:
    blocked_prefix, msg = DEPRECATED_CHECKPOINT_MAP[checkpoint_type]
    for name, _ in checkpoint_variables:
      if name.startswith(blocked_prefix):
        tf.logging.error('Found variable name - %s with prefix %s', name,
                         blocked_prefix)
        raise RuntimeError(msg)

  # Training would overwrite the fine-tune checkpoint if both live in the
  # same directory, so refuse that configuration up front.
  checkpoint_dir = os.path.abspath(os.path.dirname(checkpoint_path))
  resolved_model_dir = os.path.abspath(model_dir)
  if resolved_model_dir == checkpoint_dir:
    raise RuntimeError(
        ('Checkpoint dir ({}) and model_dir ({}) cannot be same.'.format(
            checkpoint_dir, resolved_model_dir) +
         (' Please set model_dir to a different path.')))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment