Commit 12970cb3 authored by Kaushik Shivakumar

revert faster rcnn

parent eb75e684
@@ -261,31 +261,6 @@ class FasterRCNNKerasFeatureExtractor(object):
     """Get model that extracts second stage box classifier features."""
     pass
 
-  def restore_from_classification_checkpoint_fn(
-      self,
-      first_stage_feature_extractor_scope,
-      second_stage_feature_extractor_scope):
-    """Returns a map of variables to load from a foreign checkpoint.
-
-    Args:
-      first_stage_feature_extractor_scope: A scope name for the first stage
-        feature extractor.
-      second_stage_feature_extractor_scope: A scope name for the second stage
-        feature extractor.
-
-    Returns:
-      A dict mapping variable names (to load from a checkpoint) to variables in
-        the model graph.
-    """
-    variables_to_restore = {}
-    for variable in variables_helper.get_global_variables_safely():
-      for scope_name in [first_stage_feature_extractor_scope,
-                         second_stage_feature_extractor_scope]:
-        if variable.op.name.startswith(scope_name):
-          var_name = variable.op.name.replace(scope_name + '/', '')
-          variables_to_restore[var_name] = variable
-    return variables_to_restore
-
 
 class FasterRCNNMetaArch(model.DetectionModel):
   """Faster R-CNN Meta-architecture definition."""
@@ -2801,6 +2776,46 @@ class FasterRCNNMetaArch(model.DetectionModel):
         variables_to_restore, include_patterns=include_patterns)
     return {var.op.name: var for var in feature_extractor_variables}
 
+  def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
+    """Returns a map of Trackable objects to load from a foreign checkpoint.
+
+    Returns a dictionary of Tensorflow 2 Trackable objects (e.g. tf.Module
+    or Checkpoint). This enables the model to initialize based on weights from
+    another task. For example, the feature extractor variables from a
+    classification model can be used to bootstrap training of an object
+    detector. When loading from an object detection model, the checkpoint model
+    should have the same parameters as this detection model with the exception
+    of the num_classes parameter.
+
+    Note that this function is intended to be used to restore Keras-based
+    models when running Tensorflow 2, whereas restore_map (above) is intended
+    to be used to restore Slim-based models when running Tensorflow 1.x.
+
+    Args:
+      fine_tune_checkpoint_type: whether to restore from a full detection
+        checkpoint (with compatible variable names) or to restore from a
+        classification checkpoint for initialization prior to training.
+        Valid values: `detection`, `classification`. Default 'detection'.
+
+    Returns:
+      A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
+    """
+    if fine_tune_checkpoint_type == 'classification':
+      return {
+          'feature_extractor':
+              self._feature_extractor.classification_backbone
+      }
+    elif fine_tune_checkpoint_type == 'detection':
+      fake_model = tf.train.Checkpoint(
+          _feature_extractor_for_box_classifier_features=
+          self._feature_extractor_for_box_classifier_features,
+          _feature_extractor_for_proposal_features=
+          self._feature_extractor_for_proposal_features)
+      return {'model': fake_model}
+    else:
+      raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
+          fine_tune_checkpoint_type))
+
   def updates(self):
     """Returns a list of update operators for this model.
...
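A usage sketch (again, not part of this commit) for the restored restore_from_objects path under TensorFlow 2: the returned Trackables are wrapped in an outer tf.train.Checkpoint whose restore() performs the partial load. detection_model and the checkpoint path are illustrative placeholders.

import tensorflow as tf

# Assumed: `detection_model` is a built FasterRCNNMetaArch instance.
restore_objects = detection_model.restore_from_objects(
    fine_tune_checkpoint_type='detection')

# Wrap the Trackables in an outer Checkpoint; restore() matches objects
# by the dict keys ('model' here) rather than by flat variable names.
ckpt = tf.train.Checkpoint(**restore_objects)

# expect_partial() silences warnings about checkpoint values (e.g. box
# predictor heads or optimizer slots) that this partial load skips.
ckpt.restore('/path/to/detection/ckpt-0').expect_partial()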