Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fa45b626
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1ac38ec89c5899f44f84e44ee461c714544c6af0"
Commit
fa45b626
authored
Jul 17, 2017
by
Derek Chow
Browse files
Change model.restore_fn to return a variable map instead of init_fn.
parent
a57a00f6
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
76 additions
and
100 deletions
+76
-100
object_detection/core/model.py
object_detection/core/model.py
+7
-8
object_detection/exporter_test.py
object_detection/exporter_test.py
+1
-1
object_detection/meta_architectures/BUILD
object_detection/meta_architectures/BUILD
+0
-1
object_detection/meta_architectures/faster_rcnn_meta_arch.py
object_detection/meta_architectures/faster_rcnn_meta_arch.py
+15
-31
object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py
...tion/meta_architectures/faster_rcnn_meta_arch_test_lib.py
+14
-8
object_detection/meta_architectures/ssd_meta_arch.py
object_detection/meta_architectures/ssd_meta_arch.py
+9
-17
object_detection/meta_architectures/ssd_meta_arch_test.py
object_detection/meta_architectures/ssd_meta_arch_test.py
+10
-8
object_detection/models/BUILD
object_detection/models/BUILD
+0
-1
object_detection/models/faster_rcnn_inception_resnet_v2_feature_extractor.py
...dels/faster_rcnn_inception_resnet_v2_feature_extractor.py
+9
-16
object_detection/trainer.py
object_detection/trainer.py
+7
-2
object_detection/trainer_test.py
object_detection/trainer_test.py
+4
-7
No files found.
object_detection/core/model.py
View file @
fa45b626
...
@@ -228,25 +228,24 @@ class DetectionModel(object):
...
@@ -228,25 +228,24 @@ class DetectionModel(object):
fields
.
BoxListFields
.
keypoints
]
=
groundtruth_keypoints_list
fields
.
BoxListFields
.
keypoints
]
=
groundtruth_keypoints_list
@
abstractmethod
@
abstractmethod
def
restore_
fn
(
self
,
checkpoint_path
,
from_detection_checkpoint
=
True
):
def
restore_
map
(
self
,
from_detection_checkpoint
=
True
):
"""Return
callable for loading a foreign checkpoint into tensorflow graph
.
"""Return
s a map of variables to load from a foreign checkpoint
.
Loads variables from a different tensorflow graph (typically feature
Returns a map of variable names to load from a checkpoint to variables in
extractor variables)
. This enables the model to initialize based on weights
the model graph
. This enables the model to initialize based on weights
from
from
another task. For example, the feature extractor variables from a
another task. For example, the feature extractor variables from a
classification model can be used to bootstrap training of an object
classification model can be used to bootstrap training of an object
detector. When loading from an object detection model, the checkpoint model
detector. When loading from an object detection model, the checkpoint model
should have the same parameters as this detection model with exception of
should have the same parameters as this detection model with exception of
the num_classes parameter.
the num_classes parameter.
Args:
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
classification checkpoint for initialization prior to training.
Returns:
Returns:
a callable which takes a tf.Session as input and loads a checkpoint whe
n
A dict mapping variable names (to load from a checkpoint) to variables i
n
run
.
the model graph
.
"""
"""
pass
pass
object_detection/exporter_test.py
View file @
fa45b626
...
@@ -54,7 +54,7 @@ class FakeModel(model.DetectionModel):
...
@@ -54,7 +54,7 @@ class FakeModel(model.DetectionModel):
np
.
arange
(
32
).
reshape
([
2
,
4
,
4
]),
tf
.
float32
)
np
.
arange
(
32
).
reshape
([
2
,
4
,
4
]),
tf
.
float32
)
return
postprocessed_tensors
return
postprocessed_tensors
def
restore_
fn
(
self
,
checkpoint_path
,
from_detection_checkpoint
):
def
restore_
map
(
self
,
checkpoint_path
,
from_detection_checkpoint
):
pass
pass
def
loss
(
self
,
prediction_dict
):
def
loss
(
self
,
prediction_dict
):
...
...
object_detection/meta_architectures/BUILD
View file @
fa45b626
...
@@ -56,7 +56,6 @@ py_library(
...
@@ -56,7 +56,6 @@ py_library(
"//tensorflow_models/object_detection/core:standard_fields"
,
"//tensorflow_models/object_detection/core:standard_fields"
,
"//tensorflow_models/object_detection/core:target_assigner"
,
"//tensorflow_models/object_detection/core:target_assigner"
,
"//tensorflow_models/object_detection/utils:ops"
,
"//tensorflow_models/object_detection/utils:ops"
,
"//tensorflow_models/object_detection/utils:variables_helper"
,
],
],
)
)
...
...
object_detection/meta_architectures/faster_rcnn_meta_arch.py
View file @
fa45b626
...
@@ -80,7 +80,6 @@ from object_detection.core import post_processing
...
@@ -80,7 +80,6 @@ from object_detection.core import post_processing
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
target_assigner
from
object_detection.core
import
target_assigner
from
object_detection.utils
import
ops
from
object_detection.utils
import
ops
from
object_detection.utils
import
variables_helper
slim
=
tf
.
contrib
.
slim
slim
=
tf
.
contrib
.
slim
...
@@ -159,21 +158,19 @@ class FasterRCNNFeatureExtractor(object):
...
@@ -159,21 +158,19 @@ class FasterRCNNFeatureExtractor(object):
def
restore_from_classification_checkpoint_fn
(
def
restore_from_classification_checkpoint_fn
(
self
,
self
,
checkpoint_path
,
first_stage_feature_extractor_scope
,
first_stage_feature_extractor_scope
,
second_stage_feature_extractor_scope
):
second_stage_feature_extractor_scope
):
"""Returns
callable for loading a checkpoint into the tensorflow graph
.
"""Returns
a map of variables to load from a foreign checkpoint
.
Args:
Args:
checkpoint_path: path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
feature extractor.
Returns:
Returns:
a callable which takes a tf.Session as input and loads a checkpoint whe
n
A dict mapping variable names (to load from a checkpoint) to variables i
n
run
.
the model graph
.
"""
"""
variables_to_restore
=
{}
variables_to_restore
=
{}
for
variable
in
tf
.
global_variables
():
for
variable
in
tf
.
global_variables
():
...
@@ -182,13 +179,7 @@ class FasterRCNNFeatureExtractor(object):
...
@@ -182,13 +179,7 @@ class FasterRCNNFeatureExtractor(object):
if
variable
.
op
.
name
.
startswith
(
scope_name
):
if
variable
.
op
.
name
.
startswith
(
scope_name
):
var_name
=
variable
.
op
.
name
.
replace
(
scope_name
+
'/'
,
''
)
var_name
=
variable
.
op
.
name
.
replace
(
scope_name
+
'/'
,
''
)
variables_to_restore
[
var_name
]
=
variable
variables_to_restore
[
var_name
]
=
variable
variables_to_restore
=
(
return
variables_to_restore
variables_helper
.
get_variables_available_in_checkpoint
(
variables_to_restore
,
checkpoint_path
))
saver
=
tf
.
train
.
Saver
(
variables_to_restore
)
def
restore
(
sess
):
saver
.
restore
(
sess
,
checkpoint_path
)
return
restore
class
FasterRCNNMetaArch
(
model
.
DetectionModel
):
class
FasterRCNNMetaArch
(
model
.
DetectionModel
):
...
@@ -1413,25 +1404,22 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -1413,25 +1404,22 @@ class FasterRCNNMetaArch(model.DetectionModel):
cls_losses
=
tf
.
expand_dims
(
single_image_cls_loss
,
0
),
cls_losses
=
tf
.
expand_dims
(
single_image_cls_loss
,
0
),
decoded_boxlist_list
=
[
proposal_boxlist
])
decoded_boxlist_list
=
[
proposal_boxlist
])
def
restore_fn
(
self
,
checkpoint_path
,
from_detection_checkpoint
=
True
):
def
restore_map
(
self
,
from_detection_checkpoint
=
True
):
"""Returns callable for loading a checkpoint into the tensorflow graph.
"""Returns a map of variables to load from a foreign checkpoint.
See parent class for details.
Args:
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
from_detection_checkpoint: whether to restore from a detection checkpoint
checkpoint (with compatible variable names) or to restore from a
(with compatible variable names) or to restore from a classification
classification checkpoint for initialization prior to training.
checkpoint for initialization prior to training. Note that when
from_detection_checkpoint=True, the current implementation only
supports restoration from an (exactly) identical model (with exception
of the num_classes parameter).
Returns:
Returns:
a callable which takes a tf.Session as input and loads a checkpoint whe
n
A dict mapping variable names (to load from a checkpoint) to variables i
n
run
.
the model graph
.
"""
"""
if
not
from_detection_checkpoint
:
if
not
from_detection_checkpoint
:
return
self
.
_feature_extractor
.
restore_from_classification_checkpoint_fn
(
return
self
.
_feature_extractor
.
restore_from_classification_checkpoint_fn
(
checkpoint_path
,
self
.
first_stage_feature_extractor_scope
,
self
.
first_stage_feature_extractor_scope
,
self
.
second_stage_feature_extractor_scope
)
self
.
second_stage_feature_extractor_scope
)
...
@@ -1439,13 +1427,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -1439,13 +1427,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
variables_to_restore
.
append
(
slim
.
get_or_create_global_step
())
variables_to_restore
.
append
(
slim
.
get_or_create_global_step
())
# Only load feature extractor variables to be consistent with loading from
# Only load feature extractor variables to be consistent with loading from
# a classification checkpoint.
# a classification checkpoint.
f
irst_stage
_variables
=
tf
.
contrib
.
framework
.
filter_variables
(
f
eature_extractor
_variables
=
tf
.
contrib
.
framework
.
filter_variables
(
variables_to_restore
,
variables_to_restore
,
include_patterns
=
[
self
.
first_stage_feature_extractor_scope
,
include_patterns
=
[
self
.
first_stage_feature_extractor_scope
,
self
.
second_stage_feature_extractor_scope
])
self
.
second_stage_feature_extractor_scope
])
return
{
var
.
op
.
name
:
var
for
var
in
feature_extractor_variables
}
saver
=
tf
.
train
.
Saver
(
first_stage_variables
)
def
restore
(
sess
):
saver
.
restore
(
sess
,
checkpoint_path
)
return
restore
object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py
View file @
fa45b626
...
@@ -957,7 +957,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -957,7 +957,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
exp_loc_loss
)
exp_loc_loss
)
self
.
assertAllClose
(
loss_dict_out
[
'second_stage_classification_loss'
],
0
)
self
.
assertAllClose
(
loss_dict_out
[
'second_stage_classification_loss'
],
0
)
def
test_restore_
fn
_classification
(
self
):
def
test_restore_
map_for
_classification
_ckpt
(
self
):
# Define mock tensorflow classification graph and save variables.
# Define mock tensorflow classification graph and save variables.
test_graph_classification
=
tf
.
Graph
()
test_graph_classification
=
tf
.
Graph
()
with
test_graph_classification
.
as_default
():
with
test_graph_classification
.
as_default
():
...
@@ -986,12 +986,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -986,12 +986,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
preprocessed_inputs
=
model
.
preprocess
(
inputs
)
preprocessed_inputs
=
model
.
preprocess
(
inputs
)
prediction_dict
=
model
.
predict
(
preprocessed_inputs
)
prediction_dict
=
model
.
predict
(
preprocessed_inputs
)
model
.
postprocess
(
prediction_dict
)
model
.
postprocess
(
prediction_dict
)
restore_fn
=
model
.
restore_fn
(
saved_model_path
,
var_map
=
model
.
restore_map
(
from_detection_checkpoint
=
False
)
from_detection_checkpoint
=
False
)
self
.
assertIsInstance
(
var_map
,
dict
)
saver
=
tf
.
train
.
Saver
(
var_map
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
restore_fn
(
sess
)
saver
.
restore
(
sess
,
saved_model_path
)
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
self
.
assertNotIn
(
model
.
first_stage_feature_extractor_scope
,
var
.
name
)
self
.
assertNotIn
(
model
.
second_stage_feature_extractor_scope
,
var
.
name
)
def
test_restore_
fn
_detection
(
self
):
def
test_restore_
map_for
_detection
_ckpt
(
self
):
# Define first detection graph and save variables.
# Define first detection graph and save variables.
test_graph_detection1
=
tf
.
Graph
()
test_graph_detection1
=
tf
.
Graph
()
with
test_graph_detection1
.
as_default
():
with
test_graph_detection1
.
as_default
():
...
@@ -1022,10 +1027,11 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -1022,10 +1027,11 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
preprocessed_inputs2
=
model2
.
preprocess
(
inputs2
)
preprocessed_inputs2
=
model2
.
preprocess
(
inputs2
)
prediction_dict2
=
model2
.
predict
(
preprocessed_inputs2
)
prediction_dict2
=
model2
.
predict
(
preprocessed_inputs2
)
model2
.
postprocess
(
prediction_dict2
)
model2
.
postprocess
(
prediction_dict2
)
restore_fn
=
model2
.
restore_fn
(
saved_model_path
,
var_map
=
model2
.
restore_map
(
from_detection_checkpoint
=
True
)
from_detection_checkpoint
=
True
)
self
.
assertIsInstance
(
var_map
,
dict
)
saver
=
tf
.
train
.
Saver
(
var_map
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
restore
_fn
(
sess
)
saver
.
restore
(
sess
,
saved_model_path
)
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
self
.
assertNotIn
(
model2
.
first_stage_feature_extractor_scope
,
var
.
name
)
self
.
assertNotIn
(
model2
.
first_stage_feature_extractor_scope
,
var
.
name
)
self
.
assertNotIn
(
model2
.
second_stage_feature_extractor_scope
,
self
.
assertNotIn
(
model2
.
second_stage_feature_extractor_scope
,
...
...
object_detection/meta_architectures/ssd_meta_arch.py
View file @
fa45b626
...
@@ -29,7 +29,6 @@ from object_detection.core import box_predictor as bpredictor
...
@@ -29,7 +29,6 @@ from object_detection.core import box_predictor as bpredictor
from
object_detection.core
import
model
from
object_detection.core
import
model
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
target_assigner
from
object_detection.core
import
target_assigner
from
object_detection.utils
import
variables_helper
slim
=
tf
.
contrib
.
slim
slim
=
tf
.
contrib
.
slim
...
@@ -562,33 +561,26 @@ class SSDMetaArch(model.DetectionModel):
...
@@ -562,33 +561,26 @@ class SSDMetaArch(model.DetectionModel):
decoded_boxlist_list
=
decoded_boxlist_list
,
decoded_boxlist_list
=
decoded_boxlist_list
,
match_list
=
match_list
)
match_list
=
match_list
)
def
restore_fn
(
self
,
checkpoint_path
,
from_detection_checkpoint
=
True
):
def
restore_map
(
self
,
from_detection_checkpoint
=
True
):
"""Return callable for loading a checkpoint into the tensorflow graph.
"""Returns a map of variables to load from a foreign checkpoint.
See parent class for details.
Args:
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
classification checkpoint for initialization prior to training.
Returns:
Returns:
a callable which takes a tf.Session as input and loads a checkpoint whe
n
A dict mapping variable names (to load from a checkpoint) to variables i
n
run
.
the model graph
.
"""
"""
variables_to_restore
=
{}
variables_to_restore
=
{}
for
variable
in
tf
.
all_variables
():
for
variable
in
tf
.
all_variables
():
if
variable
.
op
.
name
.
startswith
(
self
.
_extract_features_scope
):
if
variable
.
op
.
name
.
startswith
(
self
.
_extract_features_scope
):
var_name
=
variable
.
op
.
name
var_name
=
variable
.
op
.
name
if
not
from_detection_checkpoint
:
if
not
from_detection_checkpoint
:
var_name
=
(
var_name
=
(
re
.
split
(
'^'
+
self
.
_extract_features_scope
+
'/'
,
re
.
split
(
'^'
+
self
.
_extract_features_scope
+
'/'
,
var_name
)[
-
1
])
var_name
)[
-
1
])
variables_to_restore
[
var_name
]
=
variable
variables_to_restore
[
var_name
]
=
variable
# TODO: Load variables selectively using scopes.
return
variables_to_restore
variables_to_restore
=
(
variables_helper
.
get_variables_available_in_checkpoint
(
variables_to_restore
,
checkpoint_path
))
saver
=
tf
.
train
.
Saver
(
variables_to_restore
)
def
restore
(
sess
):
saver
.
restore
(
sess
,
checkpoint_path
)
return
restore
object_detection/meta_architectures/ssd_meta_arch_test.py
View file @
fa45b626
...
@@ -207,20 +207,21 @@ class SsdMetaArchTest(tf.test.TestCase):
...
@@ -207,20 +207,21 @@ class SsdMetaArchTest(tf.test.TestCase):
self
.
assertAllClose
(
losses_out
[
'classification_loss'
],
self
.
assertAllClose
(
losses_out
[
'classification_loss'
],
expected_classification_loss
)
expected_classification_loss
)
def
test_restore_
fn
_detection
(
self
):
def
test_restore_
map_for
_detection
_ckpt
(
self
):
init_op
=
tf
.
global_variables_initializer
()
init_op
=
tf
.
global_variables_initializer
()
saver
=
tf_saver
.
Saver
()
saver
=
tf_saver
.
Saver
()
save_path
=
self
.
get_temp_dir
()
save_path
=
self
.
get_temp_dir
()
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
sess
.
run
(
init_op
)
sess
.
run
(
init_op
)
saved_model_path
=
saver
.
save
(
sess
,
save_path
)
saved_model_path
=
saver
.
save
(
sess
,
save_path
)
restore_fn
=
self
.
_model
.
restore_fn
(
saved_model_path
,
var_map
=
self
.
_model
.
restore_map
(
from_detection_checkpoint
=
True
)
from_detection_checkpoint
=
True
)
self
.
assertIsInstance
(
var_map
,
dict
)
restore_fn
(
sess
)
saver
=
tf
.
train
.
Saver
(
var_map
)
saver
.
restore
(
sess
,
saved_model_path
)
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
self
.
assertNotIn
(
'FeatureExtractor'
,
var
.
name
)
self
.
assertNotIn
(
'FeatureExtractor'
,
var
.
name
)
def
test_restore_
fn
_classification
(
self
):
def
test_restore_
map_for
_classification
_ckpt
(
self
):
# Define mock tensorflow classification graph and save variables.
# Define mock tensorflow classification graph and save variables.
test_graph_classification
=
tf
.
Graph
()
test_graph_classification
=
tf
.
Graph
()
with
test_graph_classification
.
as_default
():
with
test_graph_classification
.
as_default
():
...
@@ -246,10 +247,11 @@ class SsdMetaArchTest(tf.test.TestCase):
...
@@ -246,10 +247,11 @@ class SsdMetaArchTest(tf.test.TestCase):
preprocessed_inputs
=
self
.
_model
.
preprocess
(
inputs
)
preprocessed_inputs
=
self
.
_model
.
preprocess
(
inputs
)
prediction_dict
=
self
.
_model
.
predict
(
preprocessed_inputs
)
prediction_dict
=
self
.
_model
.
predict
(
preprocessed_inputs
)
self
.
_model
.
postprocess
(
prediction_dict
)
self
.
_model
.
postprocess
(
prediction_dict
)
restore_fn
=
self
.
_model
.
restore_fn
(
saved_model_path
,
var_map
=
self
.
_model
.
restore_map
(
from_detection_checkpoint
=
False
)
from_detection_checkpoint
=
False
)
self
.
assertIsInstance
(
var_map
,
dict
)
saver
=
tf
.
train
.
Saver
(
var_map
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
restore
_fn
(
sess
)
saver
.
restore
(
sess
,
saved_model_path
)
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
for
var
in
sess
.
run
(
tf
.
report_uninitialized_variables
()):
self
.
assertNotIn
(
'FeatureExtractor'
,
var
.
name
)
self
.
assertNotIn
(
'FeatureExtractor'
,
var
.
name
)
...
...
object_detection/models/BUILD
View file @
fa45b626
...
@@ -94,7 +94,6 @@ py_library(
...
@@ -94,7 +94,6 @@ py_library(
deps
=
[
deps
=
[
"//tensorflow"
,
"//tensorflow"
,
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch"
,
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch"
,
"//tensorflow_models/object_detection/utils:variables_helper"
,
"//tensorflow_models/slim:inception_resnet_v2"
,
"//tensorflow_models/slim:inception_resnet_v2"
,
],
],
)
)
...
...
object_detection/models/faster_rcnn_inception_resnet_v2_feature_extractor.py
View file @
fa45b626
...
@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
...
@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
import
tensorflow
as
tf
import
tensorflow
as
tf
from
object_detection.meta_architectures
import
faster_rcnn_meta_arch
from
object_detection.meta_architectures
import
faster_rcnn_meta_arch
from
object_detection.utils
import
variables_helper
from
nets
import
inception_resnet_v2
from
nets
import
inception_resnet_v2
slim
=
tf
.
contrib
.
slim
slim
=
tf
.
contrib
.
slim
...
@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
...
@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
def
restore_from_classification_checkpoint_fn
(
def
restore_from_classification_checkpoint_fn
(
self
,
self
,
checkpoint_path
,
first_stage_feature_extractor_scope
,
first_stage_feature_extractor_scope
,
second_stage_feature_extractor_scope
):
second_stage_feature_extractor_scope
):
"""Returns
callable for loading a checkpoint into the tensorflow graph
.
"""Returns
a map of variables to load from a foreign checkpoint
.
Note that this overrides the default implementation in
Note that this overrides the default implementation in
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
InceptionResnetV2 checkpoints.
InceptionResnetV2 checkpoints.
TODO: revisit whether it's possible to force the `Repeat` namescope as
TODO(jonathanhuang,rathodv): revisit whether it's possible to force the
created in `_extract_box_classifier_features` to start counting at 2 (e.g.
`Repeat` namescope as created in `_extract_box_classifier_features` to
`Repeat_2`) so that the default restore_fn can be used.
start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
be used.
Args:
Args:
checkpoint_path: Path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
feature extractor.
Returns:
Returns:
a callable which takes a tf.Session as input and loads a checkpoint whe
n
A dict mapping variable names (to load from a checkpoint) to variables i
n
run
.
the model graph
.
"""
"""
variables_to_restore
=
{}
variables_to_restore
=
{}
for
variable
in
tf
.
global_variables
():
for
variable
in
tf
.
global_variables
():
if
variable
.
op
.
name
.
startswith
(
if
variable
.
op
.
name
.
startswith
(
...
@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
...
@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
var_name
=
var_name
.
replace
(
var_name
=
var_name
.
replace
(
second_stage_feature_extractor_scope
+
'/'
,
''
)
second_stage_feature_extractor_scope
+
'/'
,
''
)
variables_to_restore
[
var_name
]
=
variable
variables_to_restore
[
var_name
]
=
variable
variables_to_restore
=
(
return
variables_to_restore
variables_helper
.
get_variables_available_in_checkpoint
(
variables_to_restore
,
checkpoint_path
))
saver
=
tf
.
train
.
Saver
(
variables_to_restore
)
def
restore
(
sess
):
saver
.
restore
(
sess
,
checkpoint_path
)
return
restore
object_detection/trainer.py
View file @
fa45b626
...
@@ -211,9 +211,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
...
@@ -211,9 +211,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
# Create ops required to initialize the model from a given checkpoint.
# Create ops required to initialize the model from a given checkpoint.
init_fn
=
None
init_fn
=
None
if
train_config
.
fine_tune_checkpoint
:
if
train_config
.
fine_tune_checkpoint
:
init_fn
=
detection_model
.
restore_fn
(
var_map
=
detection_model
.
restore_map
(
train_config
.
fine_tune_checkpoint
,
from_detection_checkpoint
=
train_config
.
from_detection_checkpoint
)
from_detection_checkpoint
=
train_config
.
from_detection_checkpoint
)
var_map
=
variables_helper
.
get_variables_available_in_checkpoint
(
var_map
,
train_config
.
fine_tune_checkpoint
)
saver
=
tf
.
train
.
Saver
(
var_map
)
def
initializer_fn
(
sess
):
saver
.
restore
(
sess
,
train_config
.
fine_tune_checkpoint
)
init_fn
=
initializer_fn
with
tf
.
device
(
deploy_config
.
optimizer_device
()):
with
tf
.
device
(
deploy_config
.
optimizer_device
()):
total_loss
,
grads_and_vars
=
model_deploy
.
optimize_clones
(
total_loss
,
grads_and_vars
=
model_deploy
.
optimize_clones
(
...
...
object_detection/trainer_test.py
View file @
fa45b626
...
@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
...
@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
}
}
return
loss_dict
return
loss_dict
def
restore_
fn
(
self
,
checkpoint_path
,
from_detection_checkpoint
=
True
):
def
restore_
map
(
self
,
from_detection_checkpoint
=
True
):
"""Return
callable for loading a checkpoint into the tensorflow graph
.
"""Return
s a map of variables to load from a foreign checkpoint
.
Args:
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
classification checkpoint for initialization prior to training.
Returns:
Returns:
a callable which takes a tf.Session and does nothing
.
A dict mapping variable names to variables
.
"""
"""
def
restore
(
unused_sess
):
return
{
var
.
op
.
name
:
var
for
var
in
tf
.
global_variables
()}
return
return
restore
class
TrainerTest
(
tf
.
test
.
TestCase
):
class
TrainerTest
(
tf
.
test
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment