Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
acea25b9
Commit
acea25b9
authored
Feb 18, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 295849975
parent
f3600cd1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
9 deletions
+63
-9
official/modeling/training/distributed_executor.py
official/modeling/training/distributed_executor.py
+13
-4
official/vision/detection/configs/retinanet_config.py
official/vision/detection/configs/retinanet_config.py
+3
-0
official/vision/detection/executor/detection_executor.py
official/vision/detection/executor/detection_executor.py
+30
-5
official/vision/detection/main.py
official/vision/detection/main.py
+1
-0
official/vision/detection/utils/box_utils.py
official/vision/detection/utils/box_utils.py
+16
-0
No files found.
official/modeling/training/distributed_executor.py
View file @
acea25b9
...
...
@@ -64,7 +64,7 @@ class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
_
writer: The tf.SummaryWriter.
writer: The tf.SummaryWriter.
"""
def
__init__
(
self
,
model_dir
:
Text
,
name
:
Text
):
...
...
@@ -74,7 +74,7 @@ class SummaryWriter(object):
model_dir: the model folder path.
name: the summary subfolder name.
"""
self
.
_
writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
model_dir
,
name
))
self
.
writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
model_dir
,
name
))
def
__call__
(
self
,
metrics
:
Union
[
Dict
[
Text
,
float
],
float
],
step
:
int
):
"""Write metrics to summary with the given writer.
...
...
@@ -88,10 +88,10 @@ class SummaryWriter(object):
logging
.
warning
(
'Warning: summary writer prefer metrics as dictionary.'
)
metrics
=
{
'metric'
:
metrics
}
with
self
.
_
writer
.
as_default
():
with
self
.
writer
.
as_default
():
for
k
,
v
in
metrics
.
items
():
tf
.
summary
.
scalar
(
k
,
v
,
step
=
step
)
self
.
_
writer
.
flush
()
self
.
writer
.
flush
()
class
DistributedExecutor
(
object
):
...
...
@@ -122,6 +122,9 @@ class DistributedExecutor(object):
self
.
_strategy
=
strategy
self
.
_checkpoint_name
=
'ctl_step_{step}.ckpt'
self
.
_is_multi_host
=
is_multi_host
self
.
train_summary_writer
=
None
self
.
eval_summary_writer
=
None
self
.
global_train_step
=
None
@
property
def
checkpoint_name
(
self
):
...
...
@@ -395,7 +398,10 @@ class DistributedExecutor(object):
eval_metric
=
eval_metric_fn
()
train_metric
=
train_metric_fn
()
train_summary_writer
=
summary_writer_fn
(
model_dir
,
'eval_train'
)
self
.
train_summary_writer
=
train_summary_writer
.
writer
test_summary_writer
=
summary_writer_fn
(
model_dir
,
'eval_test'
)
self
.
eval_summary_writer
=
test_summary_writer
.
writer
# Continue training loop.
train_step
=
self
.
_create_train_step
(
...
...
@@ -406,6 +412,7 @@ class DistributedExecutor(object):
metric
=
train_metric
)
test_step
=
None
if
eval_input_fn
and
eval_metric
:
self
.
global_train_step
=
model
.
optimizer
.
iterations
test_step
=
self
.
_create_test_step
(
strategy
,
model
,
metric
=
eval_metric
)
logging
.
info
(
'Training started'
)
...
...
@@ -549,6 +556,7 @@ class DistributedExecutor(object):
return
True
summary_writer
=
summary_writer_fn
(
model_dir
,
'eval'
)
self
.
eval_summary_writer
=
summary_writer
.
writer
# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
...
...
@@ -615,6 +623,7 @@ class DistributedExecutor(object):
'checkpoint'
,
checkpoint_path
)
checkpoint
.
restore
(
checkpoint_path
)
self
.
global_train_step
=
model
.
optimizer
.
iterations
eval_iterator
=
self
.
_get_input_iterator
(
eval_input_fn
,
strategy
)
eval_metric_result
=
self
.
_run_evaluation
(
test_step
,
current_step
,
eval_metric
,
eval_iterator
)
...
...
official/vision/detection/configs/retinanet_config.py
View file @
acea25b9
...
...
@@ -70,6 +70,9 @@ RETINANET_CFG = {
'val_json_file'
:
''
,
'eval_file_pattern'
:
''
,
'input_sharding'
:
True
,
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize'
:
0
,
},
'predict'
:
{
'predict_batch_size'
:
8
,
...
...
official/vision/detection/executor/detection_executor.py
View file @
acea25b9
...
...
@@ -25,6 +25,7 @@ import os
import
json
import
tensorflow.compat.v2
as
tf
from
official.modeling.training
import
distributed_executor
as
executor
from
official.vision.detection.utils
import
box_utils
class
DetectionDistributedExecutor
(
executor
.
DistributedExecutor
):
...
...
@@ -38,13 +39,19 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables_filter
=
None
,
**
kwargs
):
super
(
DetectionDistributedExecutor
,
self
).
__init__
(
**
kwargs
)
params
=
kwargs
[
'params'
]
if
predict_post_process_fn
:
assert
callable
(
predict_post_process_fn
)
if
trainable_variables_filter
:
assert
callable
(
trainable_variables_filter
)
self
.
_predict_post_process_fn
=
predict_post_process_fn
self
.
_trainable_variables_filter
=
trainable_variables_filter
self
.
eval_steps
=
tf
.
Variable
(
0
,
trainable
=
False
,
dtype
=
tf
.
int32
,
synchronization
=
tf
.
VariableSynchronization
.
ON_READ
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
,
shape
=
[])
def
_create_replicated_step
(
self
,
strategy
,
...
...
@@ -90,24 +97,41 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
"""Creates a distributed test step."""
@
tf
.
function
def
test_step
(
iterator
):
def
test_step
(
iterator
,
eval_steps
):
"""Calculates evaluation metrics on distributed devices."""
def
_test_step_fn
(
inputs
):
def
_test_step_fn
(
inputs
,
eval_steps
):
"""Replicated accuracy calculation."""
inputs
,
labels
=
inputs
model_outputs
=
model
(
inputs
,
training
=
False
)
if
self
.
_predict_post_process_fn
:
labels
,
prediction_outputs
=
self
.
_predict_post_process_fn
(
labels
,
model_outputs
)
num_remaining_visualizations
=
(
self
.
_params
.
eval
.
num_images_to_visualize
-
eval_steps
)
# If any visualizations remain to be done, add the next batch of outputs
# for visualization.
#
# TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
# write correct slice of outputs to summary file.
if
num_remaining_visualizations
>
0
:
box_utils
.
visualize_bounding_boxes
(
inputs
,
prediction_outputs
[
'detection_boxes'
],
self
.
global_train_step
,
self
.
eval_summary_writer
)
return
labels
,
prediction_outputs
labels
,
outputs
=
strategy
.
experimental_run_v2
(
_test_step_fn
,
args
=
(
next
(
iterator
),))
_test_step_fn
,
args
=
(
next
(
iterator
),
eval_steps
,
))
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
)
labels
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
labels
)
eval_steps
.
assign_add
(
self
.
_params
.
eval
.
batch_size
)
return
labels
,
outputs
return
test_step
...
...
@@ -115,6 +139,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
def
_run_evaluation
(
self
,
test_step
,
current_training_step
,
metric
,
test_iterator
):
"""Runs validation steps and aggregate metrics."""
self
.
eval_steps
.
assign
(
0
)
if
not
test_iterator
or
not
metric
:
logging
.
warning
(
'Both test_iterator (%s) and metrics (%s) must not be None.'
,
...
...
@@ -123,7 +148,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
logging
.
info
(
'Running evaluation after step: %s.'
,
current_training_step
)
while
True
:
try
:
labels
,
outputs
=
test_step
(
test_iterator
)
labels
,
outputs
=
test_step
(
test_iterator
,
self
.
eval_steps
)
if
metric
:
metric
.
update_state
(
labels
,
outputs
)
except
(
StopIteration
,
tf
.
errors
.
OutOfRangeError
):
...
...
official/vision/detection/main.py
View file @
acea25b9
...
...
@@ -239,4 +239,5 @@ def main(argv):
if
__name__
==
'__main__'
:
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
tf
.
config
.
set_soft_device_placement
(
True
)
app
.
run
(
main
)
official/vision/detection/utils/box_utils.py
View file @
acea25b9
...
...
@@ -26,6 +26,22 @@ EPSILON = 1e-8
BBOX_XFORM_CLIP
=
np
.
log
(
1000.
/
16.
)
def visualize_images_with_bounding_boxes(images, box_outputs, step,
                                         summary_writer):
  """Records subset of evaluation images with bounding boxes.

  Draws `box_outputs` on top of `images` and writes the result as an image
  summary named 'bounding_box_summary' through `summary_writer`.

  Args:
    images: batch of images. The height and width used for normalization are
      read from the first two dimensions of `images[0]`.
    box_outputs: boxes to draw. They are passed through `normalize_boxes`
      together with `[image_height, image_width]` before drawing
      (presumably absolute pixel coordinates -- TODO confirm against
      `normalize_boxes`).
    step: value forwarded as the `step` of the image summary.
    summary_writer: summary writer providing `as_default()` and `flush()`.
  """
  # All images in the batch are treated as sharing the first image's size.
  image_shape = tf.shape(images[0])
  image_height = tf.cast(image_shape[0], tf.float32)
  image_width = tf.cast(image_shape[1], tf.float32)
  normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])

  # Single color row (1, 1, 0, 1) applied to the drawn boxes.
  bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
  image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
                                               bounding_box_color)
  with summary_writer.as_default():
    tf.summary.image('bounding_box_summary', image_summary, step=step)
    # Flush right away so the images are persisted with this step.
    summary_writer.flush()


# `detection_executor` invokes this helper as
# `box_utils.visualize_bounding_boxes`; keep that name resolvable so the call
# does not fail with AttributeError.
visualize_bounding_boxes = visualize_images_with_bounding_boxes
def
yxyx_to_xywh
(
boxes
):
"""Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment