Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cf01596c
"vscode:/vscode.git/clone" did not exist on "ec6f8ef99a8c6942133e01a610def197e1d6d9dd"
Commit
cf01596c
authored
Mar 04, 2020
by
Zongwei Zhou
Committed by
A. Unique TensorFlower
Mar 04, 2020
Browse files
Internal change
PiperOrigin-RevId: 299007295
parent
7a3d6c4c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
200 additions
and
24 deletions
+200
-24
official/modeling/model_training_utils.py
official/modeling/model_training_utils.py
+36
-11
official/nlp/bert/run_squad_helper.py
official/nlp/bert/run_squad_helper.py
+13
-1
official/nlp/optimization.py
official/nlp/optimization.py
+7
-1
official/staging/training/grad_utils.py
official/staging/training/grad_utils.py
+141
-0
official/vision/image_classification/resnet_runnable.py
official/vision/image_classification/resnet_runnable.py
+3
-11
No files found.
official/modeling/model_training_utils.py
View file @
cf01596c
...
...
@@ -23,6 +23,7 @@ import os
from
absl
import
logging
import
tensorflow
as
tf
from
official.staging.training
import
grad_utils
from
official.utils.misc
import
distribution_utils
_SUMMARY_TXT
=
'training_summary.txt'
...
...
@@ -94,7 +95,10 @@ def run_customized_training_loop(
init_checkpoint
=
None
,
custom_callbacks
=
None
,
run_eagerly
=
False
,
sub_model_export_name
=
None
):
sub_model_export_name
=
None
,
explicit_allreduce
=
False
,
pre_allreduce_callbacks
=
None
,
post_allreduce_callbacks
=
None
):
"""Run BERT pretrain model training using low-level API.
Arguments:
...
...
@@ -136,6 +140,23 @@ def run_customized_training_loop(
file is {sub_model_export_name}_step_{step}.ckpt and the last
checkpint's name is {sub_model_export_name}.ckpt;
if None, `sub_model` will not be exported as checkpoint.
explicit_allreduce: Whether to explicitly perform gradient allreduce,
instead of relying on implicit allreduce in optimizer.apply_gradients().
default is False. For now, if training using FP16 mixed precision,
explicit allreduce will aggregate gradients in FP16 format. For TPU and
GPU training using FP32, explicit allreduce will aggregate gradients in
FP32 format.
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables paris. The callback functions will be
invoked in the list order and before gradients are allreduced.
Default is no callbacks. Only used when explicit_allreduce=True.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables paris. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks. Only used
when explicit_allreduce=True.
Returns:
Trained model.
...
...
@@ -199,8 +220,6 @@ def run_customized_training_loop(
'sub_model is None.'
%
sub_model_export_name
)
optimizer
=
model
.
optimizer
use_float16
=
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
)
if
init_checkpoint
:
logging
.
info
(
...
...
@@ -242,15 +261,21 @@ def run_customized_training_loop(
with
tf
.
GradientTape
()
as
tape
:
model_outputs
=
model
(
inputs
,
training
=
True
)
loss
=
loss_fn
(
labels
,
model_outputs
)
if
use_float16
:
scaled_loss
=
optimizer
.
get_scaled_loss
(
loss
)
if
use_float16
:
scaled_grads
=
tape
.
gradient
(
scaled_loss
,
training_vars
)
grads
=
optimizer
.
get_unscaled_gradients
(
scaled_grads
)
if
explicit_allreduce
:
grad_utils
.
minimize_using_explicit_allreduce
(
tape
,
optimizer
,
loss
,
training_vars
,
pre_allreduce_callbacks
,
post_allreduce_callbacks
)
else
:
grads
=
tape
.
gradient
(
loss
,
training_vars
)
optimizer
.
apply_gradients
(
zip
(
grads
,
training_vars
))
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
):
with
tape
:
scaled_loss
=
optimizer
.
get_scaled_loss
(
loss
)
scaled_grads
=
tape
.
gradient
(
scaled_loss
,
training_vars
)
grads
=
optimizer
.
get_unscaled_gradients
(
scaled_grads
)
else
:
grads
=
tape
.
gradient
(
loss
,
training_vars
)
optimizer
.
apply_gradients
(
zip
(
grads
,
training_vars
))
# For reporting, the metric takes the mean of losses.
train_loss_metric
.
update_state
(
loss
)
for
metric
in
train_metrics
:
...
...
official/nlp/bert/run_squad_helper.py
View file @
cf01596c
...
...
@@ -269,6 +269,16 @@ def train_squad(strategy,
loss_factor
=
1.0
/
strategy
.
num_replicas_in_sync
if
FLAGS
.
scale_loss
else
1.0
)
# when all_reduce_sum_gradients = False, apply_gradients() no longer
# implicitly allreduce gradients, users manually allreduce gradient and
# passed the allreduced grads_and_vars. For now, the clip_by_global_norm
# will be moved to before users' manual allreduce to keep the math
# unchanged.
def
clip_by_global_norm_callback
(
grads_and_vars
):
grads
,
variables
=
zip
(
*
grads_and_vars
)
(
clipped_grads
,
_
)
=
tf
.
clip_by_global_norm
(
grads
,
clip_norm
=
1.0
)
return
zip
(
clipped_grads
,
variables
)
model_training_utils
.
run_customized_training_loop
(
strategy
=
strategy
,
model_fn
=
_get_squad_model
,
...
...
@@ -280,7 +290,9 @@ def train_squad(strategy,
train_input_fn
=
train_input_fn
,
init_checkpoint
=
FLAGS
.
init_checkpoint
,
run_eagerly
=
run_eagerly
,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
,
explicit_allreduce
=
True
,
pre_allreduce_callbacks
=
[
clip_by_global_norm_callback
])
def
predict_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib
):
...
...
official/nlp/optimization.py
View file @
cf01596c
...
...
@@ -142,7 +142,13 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
name
=
None
,
all_reduce_sum_gradients
=
True
):
grads
,
tvars
=
list
(
zip
(
*
grads_and_vars
))
(
grads
,
_
)
=
tf
.
clip_by_global_norm
(
grads
,
clip_norm
=
1.0
)
if
all_reduce_sum_gradients
:
# when all_reduce_sum_gradients = False, apply_gradients() no longer
# implicitly allreduce gradients, users manually allreduce gradient and
# passed the allreduced grads_and_vars. For now, the clip_by_global_norm
# will be moved to before the explicit allreduce to keep the math
# the same as TF 1 and pre TF 2.2 implementation.
(
grads
,
_
)
=
tf
.
clip_by_global_norm
(
grads
,
clip_norm
=
1.0
)
return
super
(
AdamWeightDecay
,
self
).
apply_gradients
(
zip
(
grads
,
tvars
),
name
=
name
,
...
...
official/staging/training/grad_utils.py
0 → 100644
View file @
cf01596c
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some gradient util functions to help users writing custom training loop."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
from
absl
import
logging
import
tensorflow.compat.v2
as
tf
def
_filter_grads
(
grads_and_vars
):
"""Filter out iterable with grad equal to None."""
grads_and_vars
=
tuple
(
grads_and_vars
)
if
not
grads_and_vars
:
return
grads_and_vars
filtered
=
[]
vars_with_empty_grads
=
[]
for
grad
,
var
in
grads_and_vars
:
if
grad
is
None
:
vars_with_empty_grads
.
append
(
var
)
else
:
filtered
.
append
((
grad
,
var
))
filtered
=
tuple
(
filtered
)
if
not
filtered
:
raise
ValueError
(
"No gradients provided for any variable: %s."
%
([
v
.
name
for
_
,
v
in
grads_and_vars
],))
if
vars_with_empty_grads
:
logging
.
warning
(
(
"Gradients do not exist for variables %s when minimizing the loss."
),
([
v
.
name
for
v
in
vars_with_empty_grads
]))
return
filtered
def
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
):
"""Filter None grads and then allreduce gradients in specified precision.
This utils function is used when users intent to explicitly allreduce
gradients and customize gradients operations before and after allreduce.
The allreduced gradients are then passed to optimizer.apply_gradients(
all_reduce_sum_gradients=False).
Arguments:
grads_and_vars: gradients and variables pairs.
allreduce_precision: Whether to allreduce gradients in float32 or float16.
Returns:
pairs of allreduced non-None gradients and variables.
"""
filtered_grads_and_vars
=
_filter_grads
(
grads_and_vars
)
(
grads
,
variables
)
=
zip
(
*
filtered_grads_and_vars
)
if
allreduce_precision
==
"float16"
:
grads
=
[
tf
.
cast
(
grad
,
"float16"
)
for
grad
in
grads
]
allreduced_grads
=
tf
.
distribute
.
get_replica_context
().
all_reduce
(
tf
.
distribute
.
ReduceOp
.
SUM
,
grads
)
if
allreduce_precision
==
"float16"
:
allreduced_grads
=
[
tf
.
cast
(
grad
,
"float32"
)
for
grad
in
allreduced_grads
]
return
allreduced_grads
,
variables
def
_run_callbacks
(
callbacks
,
grads_and_vars
):
for
callback
in
callbacks
:
grads_and_vars
=
callback
(
grads_and_vars
)
return
grads_and_vars
def
minimize_using_explicit_allreduce
(
tape
,
optimizer
,
loss
,
trainable_variables
,
pre_allreduce_callbacks
=
None
,
post_allreduce_callbacks
=
None
):
"""Minimizes loss for one step by updating `trainable_variables`.
Minimizes loss for one step by updating `trainable_variables`.
This explicitly performs gradient allreduce, instead of relying on implicit
allreduce in optimizer.apply_gradients(). If training using FP16 mixed
precision, explicit allreduce will aggregate gradients in FP16 format.
For TPU and GPU training using FP32, explicit allreduce will aggregate
gradients in FP32 format.
Arguments:
tape: An instance of `tf.GradientTape`.
optimizer: An instance of `tf.keras.optimizers.Optimizer`.
loss: the loss tensor.
trainable_variables: A list of model Variables.
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables pairs. The callback functions will be
invoked in the list order and before gradients are allreduced.
Default is no callbacks.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables paris. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks.
"""
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
):
# FP16 GPU code path
with
tape
:
scaled_loss
=
optimizer
.
get_scaled_loss
(
loss
)
scaled_grads
=
tape
.
gradient
(
scaled_loss
,
trainable_variables
)
grads_and_vars
=
zip
(
scaled_grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_scaled_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float16"
)
allreduced_unscaled_grads
=
optimizer
.
get_unscaled_gradients
(
allreduced_scaled_grads
)
grads_and_vars
=
zip
(
allreduced_unscaled_grads
,
filtered_training_vars
)
else
:
# TPU or FP32 GPU code path
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
grads_and_vars
=
zip
(
grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
)
grads_and_vars
=
zip
(
allreduced_grads
,
filtered_training_vars
)
if
post_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
post_allreduce_callbacks
,
grads_and_vars
)
optimizer
.
apply_gradients
(
grads_and_vars
,
all_reduce_sum_gradients
=
False
)
official/vision/image_classification/resnet_runnable.py
View file @
cf01596c
...
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
tensorflow.compat.v2
as
tf
from
official.modeling
import
performance
from
official.staging.training
import
grad_utils
from
official.staging.training
import
standard_runnable
from
official.staging.training
import
utils
from
official.utils.flags
import
core
as
flags_core
...
...
@@ -170,17 +171,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
else
:
loss
+=
(
tf
.
reduce_sum
(
self
.
model
.
losses
)
/
num_replicas
)
# Scale the loss
if
self
.
flags_obj
.
dtype
==
'fp16'
:
loss
=
self
.
optimizer
.
get_scaled_loss
(
loss
)
grads
=
tape
.
gradient
(
loss
,
self
.
model
.
trainable_variables
)
# Unscale the grads
if
self
.
flags_obj
.
dtype
==
'fp16'
:
grads
=
self
.
optimizer
.
get_unscaled_gradients
(
grads
)
self
.
optimizer
.
apply_gradients
(
zip
(
grads
,
self
.
model
.
trainable_variables
))
grad_utils
.
minimize_using_explicit_allreduce
(
tape
,
self
.
optimizer
,
loss
,
self
.
model
.
trainable_variables
)
self
.
train_loss
.
update_state
(
loss
)
self
.
train_accuracy
.
update_state
(
labels
,
logits
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment