Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
319589aa
Commit
319589aa
authored
May 07, 2021
by
vedanshu
Browse files
Merge branch 'master' of
https://github.com/tensorflow/models
parents
64f323b1
eaeea071
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
106 additions
and
22 deletions
+106
-22
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
...center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
+24
-0
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+2
-8
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+2
-8
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+7
-1
research/object_detection/protos/train.proto
research/object_detection/protos/train.proto
+19
-5
research/object_detection/utils/variables_helper.py
research/object_detection/utils/variables_helper.py
+52
-0
No files found.
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
View file @
319589aa
...
@@ -103,6 +103,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
...
@@ -103,6 +103,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
# a depth multiplier of 2.
# a depth multiplier of 2.
self
.
assertEqual
(
64
,
first_conv
.
filters
)
self
.
assertEqual
(
64
,
first_conv
.
filters
)
def
test_center_net_mobilenet_v2_fpn_feature_extractor_interpolation
(
self
):
channel_means
=
(
0.
,
0.
,
0.
)
channel_stds
=
(
1.
,
1.
,
1.
)
bgr_ordering
=
False
model
=
(
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
True
,
upsampling_interpolation
=
'bilinear'
))
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
processed_img
=
model
.
preprocess
(
img
)
return
model
(
processed_img
)
outputs
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
outputs
.
shape
,
(
8
,
56
,
56
,
24
))
# Verify the upsampling layers in the FPN use 'bilinear' interpolation.
fpn
=
model
.
get_layer
(
'model_1'
)
for
layer
in
fpn
.
layers
:
if
'up_sampling2d'
in
layer
.
name
:
self
.
assertEqual
(
'bilinear'
,
layer
.
interpolation
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
319589aa
...
@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
return
4
return
4
@
property
@
property
def
supported_sub_model_types
(
self
):
def
classification_backbone
(
self
):
return
[
'classification'
]
return
self
.
_base_model
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
319589aa
...
@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
return
4
return
4
@
property
@
property
def
supported_sub_model_types
(
self
):
def
classification_backbone
(
self
):
return
[
'classification'
]
return
self
.
_base_model
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
...
...
research/object_detection/protos/center_net.proto
View file @
319589aa
...
@@ -440,11 +440,17 @@ message CenterNetFeatureExtractor {
...
@@ -440,11 +440,17 @@ message CenterNetFeatureExtractor {
optional
bool
use_depthwise
=
5
[
default
=
false
];
optional
bool
use_depthwise
=
5
[
default
=
false
];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See
// subclasses of `CenterNetFeatureExtractor`.
optional
float
depth_multiplier
=
9
[
default
=
1.0
];
optional
float
depth_multiplier
=
9
[
default
=
1.0
];
// Whether to use separable convolutions. Only valid for specific
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
// models. See subclasses of `CenterNetFeatureExtractor`.
optional
bool
use_separable_conv
=
10
[
default
=
false
];
optional
bool
use_separable_conv
=
10
[
default
=
false
];
// Which interpolation method to use for the upsampling ops in the FPN.
// Currently only valid for CenterNetMobileNetV2FPNFeatureExtractor. The value
// can be on of 'nearest' or 'bilinear'.
optional
string
upsampling_interpolation
=
11
[
default
=
'nearest'
];
}
}
research/object_detection/protos/train.proto
View file @
319589aa
...
@@ -40,12 +40,26 @@ message TrainConfig {
...
@@ -40,12 +40,26 @@ message TrainConfig {
// extractor variables trained outside of object detection.
// extractor variables trained outside of object detection.
optional
string
fine_tune_checkpoint
=
7
[
default
=
""
];
optional
string
fine_tune_checkpoint
=
7
[
default
=
""
];
// Type of checkpoint to restore variables from, e.g. 'classification'
// This option controls how variables are restored from the (pre-trained)
// 'detection', `fine_tune`, `full`. Controls which variables are restored
// fine_tune_checkpoint. For TF2 models, 3 different types are supported:
// from the pre-trained checkpoint. For meta architecture specific valid
// 1. "classification": Restores only the classification backbone part of
// values of this parameter, see the restore_map (TF1) or
// the feature extractor. This option is typically used when you want
// to train a detection model starting from a pre-trained image
// classification model, e.g. a ResNet model pre-trained on ImageNet.
// 2. "detection": Restores the entire feature extractor. The only parts
// of the full detection model that are not restored are the box and
// class prediction heads. This option is typically used when you want
// to use a pre-trained detection model and train on a new dataset or
// task which requires different box and class prediction heads.
// 3. "full": Restores the entire detection model, including the
// feature extractor, its classification backbone, and the prediction
// heads. This option should only be used when the pre-training and
// fine-tuning tasks are the same. Otherwise, the model's parameters
// may have incompatible shapes, which will cause errors when
// attempting to restore the checkpoint.
// For more details about this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the
// restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files
// /meta_architectures/*meta_arch.py files
.
optional
string
fine_tune_checkpoint_type
=
22
[
default
=
""
];
optional
string
fine_tune_checkpoint_type
=
22
[
default
=
""
];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
...
...
research/object_detection/utils/variables_helper.py
View file @
319589aa
...
@@ -21,6 +21,7 @@ from __future__ import division
...
@@ -21,6 +21,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
logging
import
logging
import
os
import
re
import
re
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
...
@@ -29,6 +30,19 @@ import tf_slim as slim
...
@@ -29,6 +30,19 @@ import tf_slim as slim
from
tensorflow.python.ops
import
variables
as
tf_variables
from
tensorflow.python.ops
import
variables
as
tf_variables
# Maps checkpoint types to variable name prefixes that are no longer
# supported
DETECTION_FEATURE_EXTRACTOR_MSG
=
"""
\
The checkpoint type 'detection' is not supported when it contains variable
names with 'feature_extractor'. Please download the new checkpoint file
from model zoo.
"""
DEPRECATED_CHECKPOINT_MAP
=
{
'detection'
:
(
'feature_extractor'
,
DETECTION_FEATURE_EXTRACTOR_MSG
)
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
# tensorflow/contrib/framework/python/ops/variables.py
def
filter_variables
(
variables
,
filter_regex_list
,
invert
=
False
):
def
filter_variables
(
variables
,
filter_regex_list
,
invert
=
False
):
...
@@ -176,3 +190,41 @@ def get_global_variables_safely():
...
@@ -176,3 +190,41 @@ def get_global_variables_safely():
"executing eagerly. Use a Keras model's `.variables` "
"executing eagerly. Use a Keras model's `.variables` "
"attribute instead."
)
"attribute instead."
)
return
tf
.
global_variables
()
return
tf
.
global_variables
()
def
ensure_checkpoint_supported
(
checkpoint_path
,
checkpoint_type
,
model_dir
):
"""Ensures that the given checkpoint can be properly loaded.
Performs the following checks
1. Raises an error if checkpoint_path and model_dir are same.
2. Checks that checkpoint_path does not contain a deprecated checkpoint file
by inspecting its variables.
Args:
checkpoint_path: str, path to checkpoint.
checkpoint_type: str, denotes the type of checkpoint.
model_dir: The model directory to store intermediate training checkpoints.
Raises:
RuntimeError: If
1. We detect an deprecated checkpoint file.
2. model_dir and checkpoint_path are in the same directory.
"""
variables
=
tf
.
train
.
list_variables
(
checkpoint_path
)
if
checkpoint_type
in
DEPRECATED_CHECKPOINT_MAP
:
blocked_prefix
,
msg
=
DEPRECATED_CHECKPOINT_MAP
[
checkpoint_type
]
for
var_name
,
_
in
variables
:
if
var_name
.
startswith
(
blocked_prefix
):
tf
.
logging
.
error
(
'Found variable name - %s with prefix %s'
,
var_name
,
blocked_prefix
)
raise
RuntimeError
(
msg
)
checkpoint_path_dir
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
checkpoint_path
))
model_dir
=
os
.
path
.
abspath
(
model_dir
)
if
model_dir
==
checkpoint_path_dir
:
raise
RuntimeError
(
(
'Checkpoint dir ({}) and model_dir ({}) cannot be same.'
.
format
(
checkpoint_path_dir
,
model_dir
)
+
(
' Please set model_dir to a different path.'
)))
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment