Commit 172bf8ff
Authored Aug 28, 2020 by Zongwei Zhou
Committed by A. Unique TensorFlower on Aug 28, 2020

Internal change

PiperOrigin-RevId: 329042049
Parent: eee5ca5f
Showing 5 changed files with 53 additions and 18 deletions (+53 / -18):

  official/nlp/bert/common_flags.py           +7   -0
  official/nlp/bert/model_training_utils.py   +9   -2
  official/nlp/bert/run_pretraining.py        +11  -6
  official/nlp/bert/run_squad_helper.py       +8   -5
  official/staging/training/grad_utils.py     +18  -5
official/nlp/bert/common_flags.py

@@ -82,6 +82,13 @@ def define_common_bert_flags():
       'allreduce in optimizer.apply_gradients(). If fp16 mixed '
       'precision training is used, this also enables allreduce '
       'gradients in fp16.')
+  flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
+                       'Number of bytes of a gradient pack for allreduce. '
+                       'Should be positive integer, if set to 0, all '
+                       'gradients are in one pack. Breaking gradient into '
+                       'packs could enable overlap between allreduce and '
+                       'backprop computation. This flag only takes effect '
+                       'when explicit_allreduce is set to True.')
   flags_core.define_log_steps()
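The new flag is consulted only when --explicit_allreduce is also enabled. As a usage sketch (not part of the commit): the pair of flags can be exercised with absl as below. The definitions mirror the ones added in this file rather than importing the repo, and the 4 MB pack size is an illustrative value, not a recommendation from the source.

# Self-contained sketch; mirrors the flag definitions added above.
from absl import flags

flags.DEFINE_bool('explicit_allreduce', False, 'Use explicit allreduce.')
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
                     'Bytes per gradient pack; 0 means a single pack.')

FLAGS = flags.FLAGS
FLAGS([
    'run_pretraining.py',                  # argv[0] placeholder
    '--explicit_allreduce=true',           # pack size is a no-op without this
    '--allreduce_bytes_per_pack=4194304',  # 4 MB per pack (illustrative)
])
print(FLAGS.explicit_allreduce, FLAGS.allreduce_bytes_per_pack)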
official/nlp/bert/model_training_utils.py

@@ -133,7 +133,8 @@ def run_customized_training_loop(
     explicit_allreduce=False,
     pre_allreduce_callbacks=None,
     post_allreduce_callbacks=None,
-    train_summary_interval=0):
+    train_summary_interval=0,
+    allreduce_bytes_per_pack=0):
   """Run BERT pretrain model training using low-level API.

   Arguments:

@@ -201,6 +202,11 @@ def run_customized_training_loop(
       when explicit_allreduce=True.
     train_summary_interval: Step interval for training summaries. If the value
       is a negative number, then training summaries are not enabled.
+    allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+      operations into packs of certain size. If it's zero, all gradients are
+      in one pack. Breaking gradient into packs could enable overlap between
+      allreduce and backprop computation. This flag only takes effect when
+      explicit_allreduce is set to True.
   Returns:
       Trained model.

@@ -332,7 +338,8 @@ def run_customized_training_loop(
             grad_utils.minimize_using_explicit_allreduce(
                 tape, optimizer, loss, training_vars,
-                pre_allreduce_callbacks, post_allreduce_callbacks)
+                pre_allreduce_callbacks, post_allreduce_callbacks,
+                allreduce_bytes_per_pack)
           else:
             if isinstance(optimizer,
                           tf.keras.mixed_precision.experimental.LossScaleOptimizer):
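The allreduce_bytes_per_pack docstring above describes splitting the gradient list into byte-limited packs so the allreduce can overlap with backprop. A minimal, self-contained sketch of what that grouping means; pack_by_bytes is a hypothetical helper written only for illustration (tf.distribute does the real packing internally), and the sizes are arbitrary.

# Illustration only: group float32 gradients into packs of at most
# bytes_per_pack bytes; 0 keeps everything in a single pack, as in the flag.
def pack_by_bytes(grad_sizes, bytes_per_pack, bytes_per_elem=4):
  if bytes_per_pack <= 0:
    return [list(grad_sizes)]
  packs, current, current_bytes = [], [], 0
  for n in grad_sizes:
    size = n * bytes_per_elem
    if current and current_bytes + size > bytes_per_pack:
      packs.append(current)      # current pack is full; start a new one
      current, current_bytes = [], 0
    current.append(n)
    current_bytes += size
  if current:
    packs.append(current)
  return packs

# Gradients of 512, 512, 1024 and 4096 float32 elements, 8 KB packs:
print(pack_by_bytes([512, 512, 1024, 4096], bytes_per_pack=8192))
# -> [[512, 512, 1024], [4096]]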
official/nlp/bert/run_pretraining.py

@@ -109,7 +109,8 @@ def run_customized_training(strategy,
                             custom_callbacks=None,
                             explicit_allreduce=False,
                             pre_allreduce_callbacks=None,
-                            post_allreduce_callbacks=None):
+                            post_allreduce_callbacks=None,
+                            allreduce_bytes_per_pack=0):
   """Run BERT pretrain model training using low-level API."""

   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,

@@ -146,6 +147,7 @@ def run_customized_training(strategy,
       explicit_allreduce=explicit_allreduce,
       pre_allreduce_callbacks=pre_allreduce_callbacks,
       post_allreduce_callbacks=post_allreduce_callbacks,
+      allreduce_bytes_per_pack=allreduce_bytes_per_pack,
       train_summary_interval=train_summary_interval,
       custom_callbacks=custom_callbacks)

@@ -165,10 +167,12 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
     performance.set_mixed_precision_policy(common_flags.dtype())

-  # If explicit_allreduce = True, apply_gradients() no longer implicitly
-  # allreduce gradients, users manually allreduce gradient and pass the
-  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
-  # before allreduce, to be consistent with original TF1 model.
+  # Only when explicit_allreduce = True, post_allreduce_callbacks and
+  # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
+  # longer implicitly allreduce gradients, users manually allreduce gradient and
+  # pass the allreduced grads_and_vars to apply_gradients().
+  # With explicit_allreduce = True, clip_by_global_norm is moved to after
+  # allreduce.
   return run_customized_training(
       strategy,
       bert_config,

@@ -191,7 +195,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       explicit_allreduce=FLAGS.explicit_allreduce,
       pre_allreduce_callbacks=[
-          model_training_utils.clip_by_global_norm_callback])
+          model_training_utils.clip_by_global_norm_callback],
+      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)


 def main(_):
official/nlp/bert/run_squad_helper.py

@@ -260,10 +260,12 @@ def train_squad(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

-  # If explicit_allreduce = True, apply_gradients() no longer implicitly
-  # allreduce gradients, users manually allreduce gradient and pass the
-  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
-  # before allreduce, to be consistent with the original TF1 model.
+  # Only when explicit_allreduce = True, post_allreduce_callbacks and
+  # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
+  # longer implicitly allreduce gradients, users manually allreduce gradient and
+  # pass the allreduced grads_and_vars to apply_gradients().
+  # With explicit_allreduce = True, clip_by_global_norm is moved to after
+  # allreduce.
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_squad_model,

@@ -280,7 +282,8 @@ def train_squad(strategy,
       explicit_allreduce=FLAGS.explicit_allreduce,
       pre_allreduce_callbacks=[
-          model_training_utils.clip_by_global_norm_callback])
+          model_training_utils.clip_by_global_norm_callback],
+      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)


 def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
official/staging/training/grad_utils.py

@@ -48,7 +48,8 @@ def _filter_grads(grads_and_vars):
 def _filter_and_allreduce_gradients(grads_and_vars,
-                                    allreduce_precision="float32"):
+                                    allreduce_precision="float32",
+                                    bytes_per_pack=0):
   """Filter None grads and then allreduce gradients in specified precision.

   This utils function is used when users intend to explicitly allreduce

@@ -59,6 +60,8 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   Arguments:
     grads_and_vars: gradients and variables pairs.
     allreduce_precision: Whether to allreduce gradients in float32 or float16.
+    bytes_per_pack: A non-negative integer. Breaks collective operations into
+      packs of certain size. If it's zero, all gradients are in one pack.
   Returns:
     pairs of allreduced non-None gradients and variables.

@@ -67,8 +70,10 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   (grads, variables) = zip(*filtered_grads_and_vars)
   if allreduce_precision == "float16":
     grads = [tf.cast(grad, "float16") for grad in grads]
+  hints = tf.distribute.experimental.CollectiveHints(
+      bytes_per_pack=bytes_per_pack)
   allreduced_grads = tf.distribute.get_replica_context().all_reduce(
-      tf.distribute.ReduceOp.SUM, grads)
+      tf.distribute.ReduceOp.SUM, grads, experimental_hints=hints)
   if allreduce_precision == "float16":
     allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
   return allreduced_grads, variables
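For reference, the hinted all_reduce added above can be exercised on its own. A standalone sketch, assuming a TF release contemporary with this commit (around 2.3, where tf.distribute.experimental.CollectiveHints and the experimental_hints argument exist; later releases renamed them to CommunicationOptions/options). The 4 MB pack size and the tensors are illustrative.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one host, any number of devices

hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=4 * 1024 * 1024)  # illustrative 4 MB packs

def replica_fn():
  # Stand-ins for per-replica gradients.
  grads = [tf.ones([1024]), tf.ones([256, 256])]
  ctx = tf.distribute.get_replica_context()
  # Sum across replicas; the runtime may split the collective into
  # byte-limited packs according to the hints.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads,
                        experimental_hints=hints)

summed = strategy.run(replica_fn)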
@@ -85,7 +90,8 @@ def minimize_using_explicit_allreduce(tape,
                                       loss,
                                       trainable_variables,
                                       pre_allreduce_callbacks=None,
-                                      post_allreduce_callbacks=None):
+                                      post_allreduce_callbacks=None,
+                                      allreduce_bytes_per_pack=0):
   """Minimizes loss for one step by updating `trainable_variables`.

   Minimizes loss for one step by updating `trainable_variables`.

@@ -111,6 +117,9 @@ def minimize_using_explicit_allreduce(tape,
       returns a new gradients and model variables pairs. The callback
       functions will be invoked in the list order and right before gradients
       are applied to variables for updates. Default is no callbacks.
+    allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+      operations into packs of certain size. If it's zero, all gradients are
+      in one pack.
   """
   if isinstance(optimizer,
                 tf.keras.mixed_precision.experimental.LossScaleOptimizer):

@@ -123,7 +132,9 @@ def minimize_using_explicit_allreduce(tape,
     grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
     (allreduced_scaled_grads,
      filtered_training_vars) = _filter_and_allreduce_gradients(
-         grads_and_vars, allreduce_precision="float16")
+         grads_and_vars, allreduce_precision="float16",
+         bytes_per_pack=allreduce_bytes_per_pack)
     allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
         allreduced_scaled_grads)
     grads_and_vars = zip(allreduced_unscaled_grads, filtered_training_vars)

@@ -135,7 +146,9 @@ def minimize_using_explicit_allreduce(tape,
     grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
     (allreduced_grads,
      filtered_training_vars) = _filter_and_allreduce_gradients(
-         grads_and_vars, allreduce_precision="float32")
+         grads_and_vars, allreduce_precision="float32",
+         bytes_per_pack=allreduce_bytes_per_pack)
     grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
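Putting the pieces together, a possible custom training step that delegates the allreduce to grad_utils and forwards the new pack-size argument. This is only a sketch under assumptions: the Model Garden at this commit is importable, a contemporary TF (~2.3) is installed, and the toy model, data, and 4 MB pack size are placeholders rather than anything prescribed by the commit.

import tensorflow as tf
from official.staging.training import grad_utils

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.build((None, 4))
  optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(x, y):
  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x) - y))
    # Explicit-allreduce path: gradients are filtered, allreduced in
    # byte-limited packs, and then applied by the optimizer.
    grad_utils.minimize_using_explicit_allreduce(
        tape, optimizer, loss, model.trainable_variables,
        allreduce_bytes_per_pack=4 * 1024 * 1024)  # illustrative pack size
    return loss
  return strategy.run(step_fn, args=(x, y))

train_step(tf.random.normal([8, 4]), tf.random.normal([8, 1]))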