ModelZoo / ResNet50_tensorflow · Commit 5741cef6

Authored Apr 15, 2020 by Yeqing Li
Committed by A. Unique TensorFlower, Apr 15, 2020

Support multiple metrics in CTL.

PiperOrigin-RevId: 306751755
Parent: 3c227a73
Showing 1 changed file with 57 additions and 26 deletions.

official/modeling/training/distributed_executor.py  +57 −26
@@ -30,9 +30,9 @@ import tensorflow as tf
 # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils import hyperparams_flags

 FLAGS = flags.FLAGS
@@ -59,6 +59,45 @@ def _no_metric():
   return None


+def metrics_as_dict(metric):
+  """Puts input metric(s) into a list.
+
+  Args:
+    metric: metric(s) to be put into the list. `metric` could be a object, a
+      list or a dict of tf.keras.metrics.Metric or has the `required_method`.
+
+  Returns:
+    A dictionary of valid metrics.
+  """
+  if isinstance(metric, tf.keras.metrics.Metric):
+    metrics = {metric.name: metric}
+  elif isinstance(metric, list):
+    metrics = {m.name: m for m in metric}
+  elif isinstance(metric, dict):
+    metrics = metric
+  elif not metric:
+    return {}
+  else:
+    metrics = {'metric': metric}
+  return metrics
+
+
+def metric_results(metric):
+  """Collects results from the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  metric_result = {
+      name: m.result().numpy().astype(float) for name, m in metrics.items()
+  }
+  return metric_result
+
+
+def reset_states(metric):
+  """Resets states of the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  for m in metrics.values():
+    m.reset_states()
+
+
 class SummaryWriter(object):
   """Simple SummaryWriter for writing dictionary of metrics.
@@ -185,6 +224,7 @@ class DistributedExecutor(object):
                          loss_fn,
                          optimizer,
                          metric=None):
+    metrics = metrics_as_dict(metric)

     def _replicated_step(inputs):
       """Replicated training step."""
@@ -195,11 +235,8 @@ class DistributedExecutor(object):
         prediction_loss = loss_fn(labels, outputs)
         loss = tf.reduce_mean(prediction_loss)
         loss = loss / strategy.num_replicas_in_sync
-        if isinstance(metric, tf.keras.metrics.Metric):
-          metric.update_state(labels, outputs)
-        else:
-          logging.error('train metric is not an instance of '
-                        'tf.keras.metrics.Metric.')
+        for m in metrics.values():
+          m.update_state(labels, outputs)

       grads = tape.gradient(loss, model.trainable_variables)
       optimizer.apply_gradients(zip(grads, model.trainable_variables))
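The unchanged context lines above also show the loss handling that the new metric loop slots into: each replica takes the mean of its per-example loss and divides by `strategy.num_replicas_in_sync`, so that when the strategy sums gradients across replicas the applied update matches the gradient of the mean loss over the global batch. Below is a standalone sketch of that pattern (not taken from the diff), with a toy model and made-up names, assuming `tf.distribute.MirroredStrategy` semantics.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # uses whatever devices are visible

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.1)
  loss_obj = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.NONE)  # keep per-example losses

@tf.function
def train_step(dist_inputs):
  def _replicated_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Local mean loss, then scale by the replica count: the strategy sums
      # gradients across replicas, so the summed update equals the gradient
      # of the mean loss over the global batch.
      loss = tf.reduce_mean(loss_obj(labels, outputs))
      loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_loss = strategy.run(_replicated_step, args=(dist_inputs,))
  # Summing the scaled per-replica losses recovers the global mean loss.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

Feeding it would look like iterating over strategy.experimental_distribute_dataset(...) and calling train_step on each batch; the executor in this file drives that loop itself.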
@@ -235,6 +272,7 @@ class DistributedExecutor(object):
       Args:
         iterator: an iterator that yields input tensors.
         num_steps: the number of steps in the loop.
+
       Returns:
         The loss tensor.
@@ -259,6 +297,7 @@ class DistributedExecutor(object):

   def _create_test_step(self, strategy, model, metric):
     """Creates a distributed test step."""
+    metrics = metrics_as_dict(metric)

     @tf.function
     def test_step(iterator):
@@ -266,22 +305,20 @@ class DistributedExecutor(object):
       if not metric:
         logging.info('Skip test_step because metric is None (%s)', metric)
         return None, None

-      if not isinstance(metric, tf.keras.metrics.Metric):
-        raise ValueError('Metric must be an instance of tf.keras.metrics.Metric '
-                         'for running in test_step. Actual {}'.format(metric))
-
       def _test_step_fn(inputs):
         """Replicated accuracy calculation."""
         inputs, labels = inputs
         model_outputs = model(inputs, training=False)
-        metric.update_state(labels, model_outputs)
+        for m in metrics.values():
+          m.update_state(labels, model_outputs)
         return labels, model_outputs

       return strategy.run(_test_step_fn, args=(next(iterator),))

     return test_step

   def train(self,
             train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
             eval_input_fn: Callable[[params_dict.ParamsDict],
@@ -422,8 +459,9 @@ class DistributedExecutor(object):
     test_step = self._create_test_step(strategy, model, metric=eval_metric)

     # Step-0 operations
-    _save_checkpoint(checkpoint, model_dir,
-                     checkpoint_name.format(step=current_step))
+    if current_step == 0 and not latest_checkpoint_file:
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
     if test_step:
       eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
       eval_metric_result = self._run_evaluation(
@@ -432,7 +470,7 @@ class DistributedExecutor(object):
           'Step: %s evalation metric = %s.', current_step, eval_metric_result)
       test_summary_writer(
           metrics=eval_metric_result, step=optimizer.iterations)
-      eval_metric.reset_states()
+      reset_states(eval_metric)

     logging.info('Training started')
     last_save_checkpoint_step = current_step
@@ -454,12 +492,7 @@ class DistributedExecutor(object):
         raise ValueError('total loss is NaN.')

       if train_metric:
-        train_metric_result = train_metric.result()
-        if isinstance(train_metric, tf.keras.metrics.Metric):
-          train_metric_result = tf.nest.map_structure(
-              lambda x: x.numpy().astype(float), train_metric_result)
-        if not isinstance(train_metric_result, dict):
-          train_metric_result = {'metric': train_metric_result}
+        train_metric_result = metric_results(train_metric)
         train_metric_result.update(train_loss)
       else:
         train_metric_result = train_loss
@@ -496,9 +529,9 @@ class DistributedExecutor(object):

         # Re-initialize evaluation metric, except the last step.
         if eval_metric and current_step < total_steps:
-          eval_metric.reset_states()
+          reset_states(eval_metric)
         if train_metric and current_step < total_steps:
-          train_metric.reset_states()
+          reset_states(train_metric)

       # Reaches the end of training and saves the last checkpoint.
       if last_save_checkpoint_step < total_steps:
@@ -534,9 +567,7 @@ class DistributedExecutor(object):
       except (StopIteration, tf.errors.OutOfRangeError):
         break

-    metric_result = metric.result()
-    if isinstance(metric, tf.keras.metrics.Metric):
-      metric_result = metric_result.numpy().astype(float)
+    metric_result = metric_results(metric)
     logging.info('Step: [%d] Validation metric = %f', current_training_step,
                  metric_result)
     return metric_result
@@ -653,7 +684,7 @@ class DistributedExecutor(object):
     logging.info(
         'Step: %s evalation metric = %s.', current_step, eval_metric_result)
     summary_writer(metrics=eval_metric_result, step=current_step)
-    eval_metric.reset_states()
+    reset_states(eval_metric)

     return eval_metric_result, current_step
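Taken together, the change means the `metric` handed to the train and test steps no longer has to be a single `tf.keras.metrics.Metric`: a list or a dict now works the same way, because every call site goes through `metrics_as_dict`, `metric_results`, and `reset_states`. A rough, self-contained sketch of that update, collect, reset cycle follows; the toy data, metric names, and `test_step` wrapper are illustrations, not the executor's actual API.

import tensorflow as tf
from official.modeling.training.distributed_executor import (
    metrics_as_dict, metric_results, reset_states)

# The caller may now provide several metrics at once, e.g. as a list.
eval_metric = [
    tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.SparseCategoricalCrossentropy(name='ce', from_logits=True),
]
metrics = metrics_as_dict(eval_metric)

def test_step(labels, model_outputs):
  # Mirrors the new loop in _test_step_fn: every metric sees every batch.
  for m in metrics.values():
    m.update_state(labels, model_outputs)

# Toy evaluation pass over two batches of logits.
test_step(tf.constant([0, 1]), tf.constant([[2.0, 0.1], [0.3, 1.5]]))
test_step(tf.constant([1, 1]), tf.constant([[0.2, 1.9], [1.2, 0.4]]))

# Mirrors _run_evaluation: collect a {name: float} dict for logging and
# summaries, then clear the state before the next evaluation.
print(metric_results(eval_metric))   # e.g. {'accuracy': 0.75, 'ce': ...}
reset_states(eval_metric)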