ModelZoo / ResNet50_tensorflow · Commits

Commit 7359586f
Authored Jul 24, 2020 by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

Parents: c594cecf, a78b05b9

Changes: 28 changed files in total; this page shows 8 of them, with 76 additions and 30 deletions (+76 −30).
research/object_detection/meta_architectures/context_rcnn_meta_arch.py            +14  −5
research/object_detection/meta_architectures/context_rcnn_meta_arch_test.py        +6  −7
research/object_detection/model_lib_v2.py                                         +14  −5
research/object_detection/model_main_tf2.py                                        +8  −1
research/object_detection/models/center_net_hourglass_feature_extractor.py         +8  −2
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py      +8  −2
research/object_detection/models/center_net_resnet_feature_extractor.py            +9  −4
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py     +9  −4
research/object_detection/meta_architectures/context_rcnn_meta_arch.py
@@ -27,7 +27,9 @@ import functools
 from object_detection.core import standard_fields as fields
 from object_detection.meta_architectures import context_rcnn_lib
+from object_detection.meta_architectures import context_rcnn_lib_tf2
 from object_detection.meta_architectures import faster_rcnn_meta_arch
+from object_detection.utils import tf_version


 class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
@@ -264,11 +266,17 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
             return_raw_detections_during_predict),
         output_final_box_features=output_final_box_features)
-    self._context_feature_extract_fn = functools.partial(
-        context_rcnn_lib.compute_box_context_attention,
-        bottleneck_dimension=attention_bottleneck_dimension,
-        attention_temperature=attention_temperature,
-        is_training=is_training)
+    if tf_version.is_tf1():
+      self._context_feature_extract_fn = functools.partial(
+          context_rcnn_lib.compute_box_context_attention,
+          bottleneck_dimension=attention_bottleneck_dimension,
+          attention_temperature=attention_temperature,
+          is_training=is_training)
+    else:
+      self._context_feature_extract_fn = context_rcnn_lib_tf2.AttentionBlock(
+          bottleneck_dimension=attention_bottleneck_dimension,
+          attention_temperature=attention_temperature,
+          is_training=is_training)

   @staticmethod
   def get_side_inputs(features):
@@ -323,6 +331,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
     Returns:
       A float32 Tensor with shape [K, new_height, new_width, depth].
     """
     box_features = self._crop_and_resize_fn(
         [features_to_crop], proposal_boxes_normalized, None,
         [self._initial_crop_size, self._initial_crop_size])
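The hunk above now chooses the context feature extractor at construction time: under TF1 it keeps the functools.partial wrapper around compute_box_context_attention, while under TF2 it instantiates a context_rcnn_lib_tf2.AttentionBlock. Below is a minimal, self-contained sketch of that version-dispatch pattern; the attention implementations here are placeholders, not the real object_detection code.

# Sketch of the TF1/TF2 dispatch pattern used in __init__ above.
# attention_v1 and AttentionBlockV2 are placeholders for
# context_rcnn_lib.compute_box_context_attention and
# context_rcnn_lib_tf2.AttentionBlock; they are not the real implementations.
import functools

import tensorflow as tf


def attention_v1(features, bottleneck_dimension, attention_temperature,
                 is_training):
  """Stands in for the TF1 functional attention op; passes features through."""
  del bottleneck_dimension, attention_temperature, is_training  # Unused here.
  return features


class AttentionBlockV2(tf.keras.layers.Layer):
  """Stands in for the TF2 Keras-layer attention block."""

  def __init__(self, bottleneck_dimension, attention_temperature, is_training):
    super().__init__()
    self._bottleneck_dimension = bottleneck_dimension
    self._attention_temperature = attention_temperature
    self._is_training = is_training

  def call(self, features):
    return features


def build_context_feature_extract_fn(use_tf1, bottleneck_dimension=16,
                                     attention_temperature=0.01,
                                     is_training=False):
  """Returns a callable with the same call signature under either TF version."""
  if use_tf1:
    return functools.partial(
        attention_v1,
        bottleneck_dimension=bottleneck_dimension,
        attention_temperature=attention_temperature,
        is_training=is_training)
  return AttentionBlockV2(
      bottleneck_dimension=bottleneck_dimension,
      attention_temperature=attention_temperature,
      is_training=is_training)


extract_fn = build_context_feature_extract_fn(
    use_tf1=tf.__version__.startswith('1.'))
print(extract_fn(tf.zeros([1, 4, 4, 3])).shape)

The actual change keeps a single attribute, self._context_feature_extract_fn, so the downstream prediction code never needs to know which branch was taken.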
research/object_detection/meta_architectures/context_rcnn_meta_arch_tf1_test.py → research/object_detection/meta_architectures/context_rcnn_meta_arch_test.py
@@ -109,7 +109,6 @@ class FakeFasterRCNNKerasFeatureExtractor(
 ])


-@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):

   def _get_model(self, box_predictor, **common_kwargs):
@@ -440,15 +439,16 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
             masks_are_class_agnostic=masks_are_class_agnostic,
             share_box_across_classes=share_box_across_classes),
         **common_kwargs)

+  @unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
   @mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib')
-  def test_prediction_mock(self, mock_context_rcnn_lib):
-    """Mocks the context_rcnn_lib module to test the prediction.
+  def test_prediction_mock_tf1(self, mock_context_rcnn_lib_v1):
+    """Mocks the context_rcnn_lib_v1 module to test the prediction.

     Using mock object so that we can ensure compute_box_context_attention is
     called in side the prediction function.

     Args:
-      mock_context_rcnn_lib: mock module for the context_rcnn_lib.
+      mock_context_rcnn_lib_v1: mock module for the context_rcnn_lib_v1.
     """
     model = self._build_model(is_training=False,
@@ -457,7 +457,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
         num_classes=42)

     mock_tensor = tf.ones([2, 8, 3, 3, 3], tf.float32)

-    mock_context_rcnn_lib.compute_box_context_attention.return_value = mock_tensor
+    mock_context_rcnn_lib_v1.compute_box_context_attention.return_value = mock_tensor
     inputs_shape = (2, 20, 20, 3)
     inputs = tf.cast(tf.random_uniform(inputs_shape, minval=0, maxval=255,
                                        dtype=tf.int32),
@@ -479,7 +479,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
     side_inputs = model.get_side_inputs(features)
     _ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
-    mock_context_rcnn_lib.compute_box_context_attention.assert_called_once()
+    mock_context_rcnn_lib_v1.compute_box_context_attention.assert_called_once()

   @parameterized.named_parameters(
       {'testcase_name': 'static_shapes', 'static_shapes': True},
@@ -518,7 +518,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
       }
       side_inputs = model.get_side_inputs(features)
       prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
                                       **side_inputs)
       return (prediction_dict['rpn_box_predictor_features'],
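The renamed TF1-only test above leans on unittest.mock.patch.object to swap the context_rcnn_lib attribute on context_rcnn_meta_arch, then asserts that compute_box_context_attention was invoked exactly once. A stripped-down, hypothetical version of that mocking pattern (the module and function names below are invented, not the real test harness):

# Hypothetical illustration of the mock.patch.object pattern used in
# test_prediction_mock_tf1 above; `library` and `compute_attention` are stand-ins.
import types
import unittest
from unittest import mock

# A tiny "module" whose attribute the code under test calls through.
library = types.SimpleNamespace(compute_attention=lambda features: features * 2)


def predict(features):
  """Code under test: resolves library.compute_attention at call time."""
  return library.compute_attention(features)


class PredictTest(unittest.TestCase):

  @mock.patch.object(library, 'compute_attention')
  def test_predict_calls_attention(self, mock_compute_attention):
    mock_compute_attention.return_value = 42  # Analogous to mock_tensor above.
    self.assertEqual(predict(3), 42)
    mock_compute_attention.assert_called_once()


if __name__ == '__main__':
  unittest.main()

Because predict looks the function up through the module-like object at call time, the patch is visible to it for the duration of the test, which is the same reason the real test patches the module rather than the function object directly.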
research/object_detection/model_lib_v2.py
@@ -23,6 +23,7 @@ import os
 import time

 import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf2

 from object_detection import eval_util
 from object_detection import inputs
@@ -117,7 +118,8 @@ def _compute_losses_and_predictions_dicts(
   prediction_dict = model.predict(
       preprocessed_images,
-      features[fields.InputDataFields.true_image_shape])
+      features[fields.InputDataFields.true_image_shape],
+      **model.get_side_inputs(features))
   prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)

   losses_dict = model.loss(
@@ -413,8 +415,9 @@ def train_loop(
     train_steps=None,
     use_tpu=False,
     save_final_config=False,
-    checkpoint_every_n=5000,
+    checkpoint_every_n=1000,
     checkpoint_max_to_keep=7,
+    record_summaries=True,
     **kwargs):
   """Trains a model using eager + functions.
@@ -444,6 +447,7 @@ def train_loop(
       Checkpoint every n training steps.
     checkpoint_max_to_keep:
       int, the number of most recent checkpoints to keep in the model directory.
+    record_summaries: Boolean, whether or not to record summaries.
     **kwargs: Additional keyword arguments for configuration override.
   """
   ## Parse the configs
@@ -530,8 +534,11 @@ def train_loop(
   # is the chief.
   summary_writer_filepath = get_filepath(strategy,
                                          os.path.join(model_dir, 'train'))
-  summary_writer = tf.compat.v2.summary.create_file_writer(summary_writer_filepath)
+  if record_summaries:
+    summary_writer = tf.compat.v2.summary.create_file_writer(summary_writer_filepath)
+  else:
+    summary_writer = tf2.summary.create_noop_writer()

   if use_tpu:
     num_steps_per_iteration = 100
@@ -603,7 +610,9 @@ def train_loop(
       if num_steps_per_iteration > 1:
         for _ in tf.range(num_steps_per_iteration - 1):
-          _sample_and_train(strategy, train_step_fn, data_iterator)
+          # Following suggestion on yaqs/5402607292645376
+          with tf.name_scope(''):
+            _sample_and_train(strategy, train_step_fn, data_iterator)

       return _sample_and_train(strategy, train_step_fn, data_iterator)
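Two patterns in the model_lib_v2.py hunks above are worth noting: side inputs are forwarded into model.predict via **model.get_side_inputs(features), and summary writing is now gated on record_summaries by swapping in a no-op writer. The snippet below sketches the second idiom in plain TF2; the log directory and the scalar being logged are illustrative only.

# Sketch of the conditional summary-writer idiom from the train_loop hunk above.
# The directory and the logged scalar are made up for illustration.
import tensorflow as tf

record_summaries = False  # In train_loop this arrives via the new argument/flag.

if record_summaries:
  summary_writer = tf.summary.create_file_writer('/tmp/od_train_summaries')
else:
  # Same interface as a real writer, but every write becomes a no-op.
  summary_writer = tf.summary.create_noop_writer()

with summary_writer.as_default():
  for step in range(3):
    tf.summary.scalar('loss', 1.0 / (step + 1), step=step)

Keeping a writer object in both branches means the training loop can stay unchanged; only the writer's behavior differs.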
research/object_detection/model_main_tf2.py
@@ -63,6 +63,11 @@ flags.DEFINE_integer(
     'num_workers', 1, 'When num_workers > 1, training uses '
     'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
     'MirroredStrategy.')
+flags.DEFINE_integer(
+    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
+flags.DEFINE_boolean('record_summaries', True,
+                     ('Whether or not to record summaries during'
+                      ' training.'))
 FLAGS = flags.FLAGS
@@ -101,7 +106,9 @@ def main(unused_argv):
         pipeline_config_path=FLAGS.pipeline_config_path,
         model_dir=FLAGS.model_dir,
         train_steps=FLAGS.num_train_steps,
-        use_tpu=FLAGS.use_tpu)
+        use_tpu=FLAGS.use_tpu,
+        checkpoint_every_n=FLAGS.checkpoint_every_n,
+        record_summaries=FLAGS.record_summaries)

 if __name__ == '__main__':
   tf.compat.v1.app.run()
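The two hunks above define --checkpoint_every_n and --record_summaries and forward them into model_lib_v2.train_loop. A self-contained sketch of that absl-style flag plumbing follows; run_training is a made-up stand-in for train_loop, not the real entry point.

# Stand-alone sketch of the flag plumbing added in model_main_tf2.py above.
# run_training is a hypothetical stand-in for model_lib_v2.train_loop.
from absl import app
from absl import flags

flags.DEFINE_integer('checkpoint_every_n', 1000,
                     'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
                     'Whether or not to record summaries during training.')
FLAGS = flags.FLAGS


def run_training(checkpoint_every_n, record_summaries):
  print('checkpoint_every_n=%d record_summaries=%s'
        % (checkpoint_every_n, record_summaries))


def main(unused_argv):
  run_training(checkpoint_every_n=FLAGS.checkpoint_every_n,
               record_summaries=FLAGS.record_summaries)


if __name__ == '__main__':
  app.run(main)

Invoked as, for example, python sketch.py --checkpoint_every_n=500 --norecord_summaries, the boolean flag also accepts absl's --no prefix form.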
research/object_detection/models/center_net_hourglass_feature_extractor.py
@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
     """Ther number of feature outputs returned by the feature extractor."""
     return self._network.num_hourglasses

-  def get_model(self):
-    return self._network
+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'detection':
+      return self._network
+    else:
+      supported_types = ['detection']
+      raise ValueError(('Sub model {} is not defined for Hourglass.'.format(
+          sub_model_type) + 'Supported types are {}.'.format(supported_types)))


 def hourglass_104(channel_means, channel_stds, bgr_ordering):
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
     """The number of feature outputs returned by the feature extractor."""
     return 1

-  def get_model(self):
-    return self._network
+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'detection':
+      return self._network
+    else:
+      supported_types = ['detection']
+      raise ValueError(('Sub model {} is not defined for MobileNet.'.format(
+          sub_model_type) + 'Supported types are {}.'.format(supported_types)))


 def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
research/object_detection/models/center_net_resnet_feature_extractor.py
@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
   def load_feature_extractor_weights(self, path):
     self._base_model.load_weights(path)

-  def get_base_model(self):
-    """Get base resnet model for inspection and testing."""
-    return self._base_model
-
   def call(self, inputs):
     """Returns image features extracted by the backbone.
@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
   def out_stride(self):
     return 4

+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'classification':
+      return self._base_model
+    else:
+      supported_types = ['classification']
+      raise ValueError(('Sub model {} is not defined for ResNet.'.format(
+          sub_model_type) + 'Supported types are {}.'.format(supported_types)))
+

 def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
   """The ResNet v2 101 feature extractor."""
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
   def load_feature_extractor_weights(self, path):
     self._base_model.load_weights(path)

-  def get_base_model(self):
-    """Get base resnet model for inspection and testing."""
-    return self._base_model
-
   def call(self, inputs):
     """Returns image features extracted by the backbone.
@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
   def out_stride(self):
     return 4

+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'classification':
+      return self._base_model
+    else:
+      supported_types = ['classification']
+      raise ValueError(('Sub model {} is not defined for ResNet FPN.'.format(
+          sub_model_type) + 'Supported types are {}.'.format(supported_types)))
+

 def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
   """The ResNet v1 101 FPN feature extractor."""
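All four CenterNet feature extractors above settle on the same get_sub_model contract: a string key selects the backbone to expose (for example for checkpoint restoration), and any other key raises a ValueError naming the supported types. The toy class below exercises that contract; it is an invented example, not one of the real extractors.

# Toy illustration of the get_sub_model contract added above.
# DummyCenterNetExtractor and its backbone are invented for this sketch.
import tensorflow as tf


class DummyCenterNetExtractor:

  def __init__(self):
    self._network = tf.keras.Sequential([tf.keras.layers.Dense(4)])

  def get_sub_model(self, sub_model_type):
    if sub_model_type == 'detection':
      return self._network
    else:
      supported_types = ['detection']
      raise ValueError('Sub model {} is not defined. '
                       'Supported types are {}.'.format(
                           sub_model_type, supported_types))


extractor = DummyCenterNetExtractor()
print(extractor.get_sub_model('detection'))   # Returns the backbone network.
try:
  extractor.get_sub_model('classification')
except ValueError as error:
  print(error)                                # Lists the supported types.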