Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f9491103
Commit
f9491103
authored
Mar 22, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Mar 22, 2021
Browse files
Add test for performance exporter
PiperOrigin-RevId: 364363163
parent
0b7674b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
6 deletions
+47
-6
research/object_detection/model_lib_tf2_test.py
research/object_detection/model_lib_tf2_test.py
+46
-5
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+1
-1
No files found.
research/object_detection/model_lib_tf2_test.py
View file @
f9491103
...
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
import
os
import
tempfile
import
unittest
...
...
@@ -26,9 +27,9 @@ import six
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v2
as
tf2
from
object_detection
import
exporter_lib_v2
from
object_detection
import
inputs
from
object_detection
import
model_lib_v2
from
object_detection.builders
import
model_builder
from
object_detection.core
import
model
from
object_detection.protos
import
train_pb2
from
object_detection.utils
import
config_util
...
...
@@ -145,6 +146,12 @@ class SimpleModel(model.DetectionModel):
return
[]
def
fake_model_builder
(
*
_
,
**
__
):
return
SimpleModel
()
FAKE_BUILDER_MAP
=
{
'build'
:
fake_model_builder
}
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
ModelCheckpointTest
(
tf
.
test
.
TestCase
):
"""Test for model checkpoint related functionality."""
...
...
@@ -153,10 +160,9 @@ class ModelCheckpointTest(tf.test.TestCase):
"""Test that only the most recent checkpoints are kept."""
strategy
=
tf2
.
distribute
.
OneDeviceStrategy
(
device
=
'/cpu:0'
)
with
mock
.
patch
.
object
(
model_builder
,
'build'
,
autospec
=
True
)
as
mock_builder
:
with
strategy
.
scope
():
mock_builder
.
return_value
=
SimpleModel
()
with
mock
.
patch
.
dict
(
exporter_lib_v2
.
INPUT_BUILDER_UTIL_MAP
,
FAKE_BUILDER_MAP
):
model_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
pipeline_config_path
=
get_pipeline_config_path
(
MODEL_NAME_FOR_TEST
)
new_pipeline_config_path
=
os
.
path
.
join
(
model_dir
,
'new_pipeline.config'
)
...
...
@@ -226,5 +232,40 @@ class CheckpointV2Test(tf.test.TestCase):
unpad_groundtruth_tensors
=
True
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
MetricsExportTest
(
tf
.
test
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
# pylint:disable=g-missing-super-call
tf
.
keras
.
backend
.
clear_session
()
def
test_export_metrics_json_serializable
(
self
):
"""Tests that Estimator and input function are constructed correctly."""
strategy
=
tf2
.
distribute
.
OneDeviceStrategy
(
device
=
'/cpu:0'
)
def
export
(
data
,
_
):
json
.
dumps
(
data
)
with
mock
.
patch
.
dict
(
exporter_lib_v2
.
INPUT_BUILDER_UTIL_MAP
,
FAKE_BUILDER_MAP
):
with
strategy
.
scope
():
model_dir
=
tf
.
test
.
get_temp_dir
()
new_pipeline_config_path
=
os
.
path
.
join
(
model_dir
,
'new_pipeline.config'
)
pipeline_config_path
=
get_pipeline_config_path
(
MODEL_NAME_FOR_TEST
)
config_util
.
clear_fine_tune_checkpoint
(
pipeline_config_path
,
new_pipeline_config_path
)
train_steps
=
2
with
strategy
.
scope
():
model_lib_v2
.
train_loop
(
new_pipeline_config_path
,
model_dir
=
model_dir
,
train_steps
=
train_steps
,
checkpoint_every_n
=
100
,
performance_summary_exporter
=
export
,
**
_get_config_kwarg_overrides
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/object_detection/model_lib_v2.py
View file @
f9491103
...
...
@@ -686,7 +686,7 @@ def train_loop(
'steps_per_sec'
:
np
.
mean
(
steps_per_sec_list
),
'steps_per_sec_p50'
:
np
.
median
(
steps_per_sec_list
),
'steps_per_sec_max'
:
max
(
steps_per_sec_list
),
'last_batch_loss'
:
loss
'last_batch_loss'
:
float
(
loss
)
}
mixed_precision
=
'bf16'
if
kwargs
[
'use_bfloat16'
]
else
'fp32'
performance_summary_exporter
(
metrics
,
mixed_precision
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment