ModelZoo / ResNet50_tensorflow · Commit bf6d6f6f

Authored Mar 17, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Mar 17, 2020.

Adds scale_loss to run_customized_training_loop, which should be the correct treatment.

PiperOrigin-RevId: 301441181
Parent: 4f58b1f7

Showing 4 changed files with 21 additions and 36 deletions (+21, -36):
  official/modeling/model_training_utils.py   +10  -1
  official/nlp/bert/run_classifier.py          +3   -15
  official/nlp/bert/run_pretraining.py         +4   -5
  official/nlp/bert/run_squad_helper.py        +4   -15
official/modeling/model_training_utils.py

@@ -102,6 +102,7 @@ def run_customized_training_loop(
     strategy=None,
     model_fn=None,
     loss_fn=None,
+    scale_loss=True,
     model_dir=None,
     train_input_fn=None,
     steps_per_epoch=None,
@@ -129,6 +130,8 @@ def run_customized_training_loop(
       to be used for initial checkpoint -- if provided.
     loss_fn: Function with signature func(labels, logits) and returns a loss
       tensor.
+    scale_loss: Whether to divide the raw loss by number of replicas before
+      gradients calculation.
     model_dir: Model directory used during training for restoring/saving model
       weights.
     train_input_fn: Function that returns a tf.data.Dataset used for training.
@@ -284,6 +287,12 @@ def run_customized_training_loop(
       with tf.GradientTape() as tape:
         model_outputs = model(inputs, training=True)
         loss = loss_fn(labels, model_outputs)
+        # Raw loss is used for reporting in metrics/logs.
+        raw_loss = loss
+        if scale_loss:
+          # Scales down the loss for gradients to be invariant from replicas.
+          loss = loss / strategy.num_replicas_in_sync
+
       if explicit_allreduce:
         grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                      training_vars,
@@ -300,7 +309,7 @@ def run_customized_training_loop(
         grads = tape.gradient(loss, training_vars)
         optimizer.apply_gradients(zip(grads, training_vars))
       # For reporting, the metric takes the mean of losses.
-      train_loss_metric.update_state(loss)
+      train_loss_metric.update_state(raw_loss)
       for metric in train_metrics:
         metric.update_state(labels, model_outputs)
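For context on why the division by num_replicas_in_sync is the correct treatment: under tf.distribute, each replica computes gradients from its local loss and the resulting gradients are summed across replicas during allreduce, so dividing the per-replica loss by the replica count keeps the applied update equal to the gradient of the global-batch mean loss. Below is a minimal, self-contained sketch of the same pattern; it is not the repository's code, and the model, optimizer, and data are placeholders.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(features, labels):

  def step_fn(x, y):
    with tf.GradientTape() as tape:
      logits = model(x, training=True)
      # Per-replica mean loss over the local batch.
      loss = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(
              y, logits, from_logits=True))
      # The unscaled loss is what gets reported in metrics/logs.
      raw_loss = loss
      # Gradients are summed across replicas during allreduce, so divide the
      # per-replica loss by the replica count to recover the gradient of the
      # global-batch mean loss.
      loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return raw_loss

  per_replica_loss = strategy.run(step_fn, args=(features, labels))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)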
official/nlp/bert/run_classifier.py

@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS

-def get_loss_fn(num_classes, loss_factor=1.0):
+def get_loss_fn(num_classes):
   """Gets the classification loss function."""

   def classification_loss_fn(labels, logits):
@@ -72,9 +72,7 @@ def get_loss_fn(num_classes, loss_factor=1.0):
         tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
-    loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_factor
-    return loss
+    return tf.reduce_mean(per_example_loss)

   return classification_loss_fn
@@ -135,17 +133,7 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

-  # During distributed training, loss used for gradient computation is
-  # summed over from all replicas. When Keras compile/fit() API is used,
-  # the fit() API internally normalizes the loss by dividing the loss by
-  # the number of replicas used for computation. However, when custom
-  # training loop is used this is not done automatically and should be
-  # done manually by the end user.
-  loss_multiplier = 1.0
-  if FLAGS.scale_loss and not use_keras_compile_fit:
-    loss_multiplier = 1.0 / strategy.num_replicas_in_sync
-
-  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)
+  loss_fn = get_loss_fn(num_classes)

   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
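The diff above shows only a fragment of classification_loss_fn. As a standalone illustration of the reduction it now performs (mean per-example softmax cross-entropy, with no loss_factor), here is a sketch on made-up logits and labels; the built-in op is included purely as a cross-check and is not part of the repository's code.

import tensorflow as tf

# Made-up values, purely illustrative.
logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])
labels = tf.constant([0, 1])
num_classes = 3

log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(
    tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)

# Same reduction as the simplified classification_loss_fn: per-example
# cross-entropy, then a plain mean with no loss_factor.
per_example_loss = -tf.reduce_sum(
    tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)

# Should match the built-in sparse softmax cross-entropy.
reference = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
print(float(loss), float(reference))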
official/nlp/bert/run_pretraining.py

@@ -74,11 +74,11 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
   return _dataset_fn

-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Returns loss function for BERT pretraining."""

   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.reduce_mean(losses) * loss_factor
+    return tf.reduce_mean(losses)

   return _bert_pretrain_loss_fn
@@ -116,9 +116,8 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(
-          loss_factor=1.0 /
-          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
+      loss_fn=get_loss_fn(),
+      scale_loss=FLAGS.scale_loss,
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
official/nlp/bert/run_squad_helper.py

@@ -90,8 +90,7 @@ FLAGS = flags.FLAGS
 def squad_loss_fn(start_positions,
                   end_positions,
                   start_logits,
-                  end_logits,
-                  loss_factor=1.0):
+                  end_logits):
   """Returns sparse categorical crossentropy for start/end logits."""
   start_loss = tf.keras.losses.sparse_categorical_crossentropy(
       start_positions, start_logits, from_logits=True)
@@ -99,11 +98,10 @@ def squad_loss_fn(start_positions,
       end_positions, end_logits, from_logits=True)

   total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
-  total_loss *= loss_factor
   return total_loss

-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Gets a loss function for squad task."""

   def _loss_fn(labels, model_outputs):
@@ -114,8 +112,7 @@ def get_loss_fn(loss_factor=1.0):
         start_positions,
         end_positions,
         start_logits,
-        end_logits,
-        loss_factor=loss_factor)
+        end_logits)
   return _loss_fn
@@ -249,14 +246,6 @@ def train_squad(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

-  # The original BERT model does not scale the loss by
-  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
-  # the same hyper parameter, we do the same thing here by keeping each
-  # replica loss as it is.
-  loss_fn = get_loss_fn(
-      loss_factor=1.0 /
-      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-
   # If explicit_allreduce = True, apply_gradients() no longer implicitly
   # allreduce gradients, users manually allreduce gradient and pass the
   # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
@@ -269,7 +258,7 @@ def train_squad(strategy,
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_squad_model,
-      loss_fn=loss_fn,
+      loss_fn=get_loss_fn(),
       model_dir=FLAGS.model_dir,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=FLAGS.steps_per_loop,
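For reference, a small standalone sketch of what the trimmed-down squad_loss_fn computes, on dummy shapes and values that are made up for illustration: the two heads' mean sparse categorical crossentropy is averaged, and any division by the replica count is now handled inside run_customized_training_loop rather than baked into the loss function.

import tensorflow as tf

# Dummy batch of 2 examples with sequence length 5, purely illustrative.
start_positions = tf.constant([1, 3])
end_positions = tf.constant([2, 4])
start_logits = tf.random.normal([2, 5])
end_logits = tf.random.normal([2, 5])

start_loss = tf.keras.losses.sparse_categorical_crossentropy(
    start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
    end_positions, end_logits, from_logits=True)

# As in the updated squad_loss_fn: average the two heads' mean losses,
# with no loss_factor applied inside the loss function.
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
print(float(total_loss))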