ModelZoo / ResNet50_tensorflow

Commit 2781377d
Authored Feb 25, 2020 by Zongwei Zhou
Committed by A. Unique TensorFlower, Feb 25, 2020

Internal change

PiperOrigin-RevId: 297222995
Parent: ed2e3bc3

Changes: 4 changed files with 106 additions and 11 deletions (+106 -11)

  official/modeling/model_training_utils.py        +80  -3
  official/modeling/model_training_utils_test.py   +20  -7
  official/nlp/bert/common_flags.py                 +4  -0
  official/nlp/bert/run_squad_helper.py             +2  -1
official/modeling/model_training_utils.py
@@ -76,6 +76,56 @@ def write_txt_summary(training_summary, summary_dir):
     f.write(json.dumps(training_summary, indent=4))


+def _filter_grads(grads_and_vars):
+  """Filter out iterable with grad equal to None."""
+  grads_and_vars = tuple(grads_and_vars)
+  if not grads_and_vars:
+    return grads_and_vars
+  filtered = []
+  vars_with_empty_grads = []
+  for grad, var in grads_and_vars:
+    if grad is None:
+      vars_with_empty_grads.append(var)
+    else:
+      filtered.append((grad, var))
+  filtered = tuple(filtered)
+  if not filtered:
+    raise ValueError('No gradients provided for any variable: %s.' %
+                     ([v.name for _, v in grads_and_vars],))
+  if vars_with_empty_grads:
+    logging.warning(
+        ('Gradients do not exist for variables %s when minimizing the loss.'),
+        ([v.name for v in vars_with_empty_grads]))
+  return filtered
+
+
+def _filter_and_allreduce_gradients(grads_and_vars,
+                                    allreduce_precision='float32'):
+  """Filter None grads and then allreduce gradients in specified precision.
+
+  This util function is used when users intend to explicitly allreduce
+  gradients and customize gradient operations before and after allreduce.
+  The allreduced gradients are then passed to optimizer.apply_gradients(
+  all_reduce_sum_gradients=False).
+
+  Arguments:
+      grads_and_vars: gradients and variables pairs.
+      allreduce_precision: Whether to allreduce gradients in float32 or float16.
+
+  Returns:
+      pairs of allreduced non-None gradients and variables.
+  """
+  filtered_grads_and_vars = _filter_grads(grads_and_vars)
+  (grads, variables) = zip(*filtered_grads_and_vars)
+  if allreduce_precision == 'float16':
+    grads = [tf.cast(grad, 'float16') for grad in grads]
+  allreduced_grads = tf.distribute.get_replica_context().all_reduce(
+      tf.distribute.ReduceOp.SUM, grads)
+  if allreduce_precision == 'float16':
+    allreduced_grads = [tf.cast(grad, 'float32') for grad in allreduced_grads]
+  return allreduced_grads, variables
+
+
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
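For orientation, below is a minimal sketch of how these helpers are intended to compose inside a replica context. It is illustrative only: it assumes a MirroredStrategy, a toy Keras model, and that the two private helpers above are importable from official.modeling.model_training_utils; the gradient clipping stands in for the "customized gradient operations" the docstring mentions and is not part of this commit. The all_reduce_sum_gradients keyword follows the docstring above (later stock TF releases expose the same switch as experimental_aggregate_gradients).

import tensorflow as tf
from official.modeling.model_training_utils import _filter_and_allreduce_gradients

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  model.build((None, 16))
  optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(features, labels):

  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      # Scale by the global batch size so that the cross-replica SUM of
      # gradients equals the gradient of the global-batch mean loss.
      loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=8)
    grads = tape.gradient(loss, model.trainable_variables)
    # Drop None gradients and SUM-allreduce the rest across replicas.
    reduced_grads, filtered_vars = _filter_and_allreduce_gradients(
        zip(grads, model.trainable_variables), allreduce_precision='float32')
    # Example of a customized gradient operation applied after allreduce
    # (hypothetical; clipping is not part of this commit).
    reduced_grads, _ = tf.clip_by_global_norm(list(reduced_grads), 1.0)
    # Gradients are already aggregated, so tell the optimizer not to
    # allreduce them again (keyword per the docstring above).
    optimizer.apply_gradients(
        zip(reduced_grads, filtered_vars), all_reduce_sum_gradients=False)
    return loss

  return strategy.run(step_fn, args=(features, labels))

train_step(tf.random.uniform([8, 16]), tf.zeros([8], tf.int32))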
@@ -94,7 +144,8 @@ def run_customized_training_loop(
     init_checkpoint=None,
     custom_callbacks=None,
     run_eagerly=False,
-    sub_model_export_name=None):
+    sub_model_export_name=None,
+    explicit_allreduce=False):
   """Run BERT pretrain model training using low-level API.

   Arguments:
@@ -136,6 +187,12 @@ def run_customized_training_loop(
         file is {sub_model_export_name}_step_{step}.ckpt and the last
         checkpoint's name is {sub_model_export_name}.ckpt;
         if None, `sub_model` will not be exported as checkpoint.
+      explicit_allreduce: Whether to explicitly perform gradient allreduce,
+        instead of relying on implicit allreduce in optimizer.apply_gradients().
+        Default is False. For now, if training using FP16 mixed precision,
+        explicit allreduce will aggregate gradients in FP16 format. For TPU and
+        GPU training using FP32, explicit allreduce will aggregate gradients in
+        FP32 format.

   Returns:
       Trained model.
@@ -251,10 +308,30 @@ def run_customized_training_loop(
       if use_float16:
         scaled_grads = tape.gradient(scaled_loss, training_vars)
-        grads = optimizer.get_unscaled_gradients(scaled_grads)
+        if explicit_allreduce:
+          (allreduced_scaled_grads,
+           filtered_training_vars) = _filter_and_allreduce_gradients(
+               zip(scaled_grads, training_vars), allreduce_precision='float16')
+          allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
+              allreduced_scaled_grads)
+          grads_and_vars = zip(allreduced_unscaled_grads,
+                               filtered_training_vars)
+        else:
+          grads = optimizer.get_unscaled_gradients(scaled_grads)
+          grads_and_vars = zip(grads, training_vars)
       else:
+        # TPU or FP32 GPU code path
         grads = tape.gradient(loss, training_vars)
-      optimizer.apply_gradients(zip(grads, training_vars))
+        if explicit_allreduce:
+          (allreduced_grads,
+           filtered_training_vars) = _filter_and_allreduce_gradients(
+               zip(grads, training_vars), allreduce_precision='float32')
+          grads_and_vars = zip(allreduced_grads, filtered_training_vars)
+        else:
+          grads_and_vars = zip(grads, training_vars)
+      optimizer.apply_gradients(
+          grads_and_vars, all_reduce_sum_gradients=not explicit_allreduce)
       # For reporting, the metric takes the mean of losses.
       train_loss_metric.update_state(loss)
       for metric in train_metrics:
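The FP16 branch above is ordered deliberately: the still-loss-scaled gradients are cast to float16 and allreduced first (halving communication volume), and only the allreduced result is unscaled. Below is a condensed, hypothetical sketch of that ordering, assuming the experimental tf.keras.mixed_precision LossScaleOptimizer API this code base targeted at the time; as in the previous sketch, the all_reduce_sum_gradients keyword is taken from this commit's docstring.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  model.build((None, 16))
  optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
      tf.keras.optimizers.SGD(0.1), loss_scale='dynamic')

@tf.function
def fp16_step(features, labels):

  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.nn.compute_average_loss(
          tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits, from_logits=True),
          global_batch_size=8)
      scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Allreduce while the gradients are still loss-scaled, in float16,
    # mirroring allreduce_precision='float16' above.
    fp16_grads = [tf.cast(g, 'float16') for g in scaled_grads]
    reduced = tf.distribute.get_replica_context().all_reduce(
        tf.distribute.ReduceOp.SUM, fp16_grads)
    reduced = [tf.cast(g, 'float32') for g in reduced]
    # The loss scale is divided out only after the allreduce.
    grads = optimizer.get_unscaled_gradients(reduced)
    # As in the loop above, the optimizer must not aggregate again.
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables), all_reduce_sum_gradients=False)
    return loss

  return strategy.run(step_fn, args=(features, labels))

fp16_step(tf.random.uniform([8, 16]), tf.zeros([8], tf.int32))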
official/modeling/model_training_utils_test.py
@@ -139,7 +139,8 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
     super(ModelTrainingUtilsTest, self).setUp()
     self._model_fn = create_model_fn(input_shape=[128], num_classes=3)

-  def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
+  def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly,
+                   explicit_allreduce=False):
     input_fn = create_fake_data_input_fn(
         batch_size=8, features_shape=[128], num_classes=3)
     model_training_utils.run_customized_training_loop(
@@ -179,12 +180,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
       self.run_training(
           distribution, model_dir, steps_per_loop=1, run_eagerly=True)

-  @combinations.generate(eager_strategy_combinations())
-  def test_train_check_artifacts(self, distribution):
-    model_dir = self.get_temp_dir()
-    self.run_training(
-        distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+  def _verify_artifacts(self, model_dir):
     # Two checkpoints should be saved after two epochs.
     self.assertNotEmpty(
         tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
     self.assertNotEmpty(
@@ -208,6 +204,23 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         check_eventfile_for_keyword(
             'mean_input', os.path.join(model_dir, 'summaries/eval')))

+  @combinations.generate(eager_strategy_combinations())
+  def test_train_check_artifacts(self, distribution):
+    model_dir = self.get_temp_dir()
+    self.run_training(
+        distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+    self._verify_artifacts(model_dir)
+
+  @combinations.generate(eager_strategy_combinations())
+  def test_train_explicit_allreduce_check_artifacts(self, distribution):
+    model_dir = self.get_temp_dir()
+    self.run_training(
+        distribution,
+        model_dir,
+        steps_per_loop=10,
+        run_eagerly=False,
+        explicit_allreduce=True)
+    self._verify_artifacts(model_dir)
+
 if __name__ == '__main__':
   assert tf.version.VERSION.startswith('2.')
official/nlp/bert/common_flags.py
@@ -68,6 +68,10 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
+  flags.DEFINE_bool('explicit_allreduce', False,
+                    'Whether to explicitly perform gradient allreduce in the '
+                    'training loop, instead of relying on implicit allreduce '
+                    'in optimizer.apply_gradients().')

   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
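A quick, hypothetical way to check the new flag's plumbing is to parse a synthetic argv and read the value back; this assumes the TF Model Garden is on PYTHONPATH. In actual training runs the flag is simply passed on the command line of a BERT driver script, for example --explicit_allreduce=true.

from absl import flags

from official.nlp.bert import common_flags

common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
# Simulate command-line parsing; the remaining flags keep their defaults
# for this illustration.
FLAGS(['prog', '--explicit_allreduce=true'])
print(FLAGS.explicit_allreduce)  # True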
official/nlp/bert/run_squad_helper.py
@@ -280,7 +280,8 @@ def train_squad(strategy,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
       run_eagerly=run_eagerly,
-      custom_callbacks=custom_callbacks)
+      custom_callbacks=custom_callbacks,
+      explicit_allreduce=FLAGS.explicit_allreduce)


 def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):