ModelZoo / ResNet50_tensorflow

Commit a009f4fb
Authored Sep 03, 2019 by Hongkun Yu
Committed by A. Unique TensorFlower, Sep 03, 2019
move collection of trainable variables outside the loop.
add a flag to control loss scaling.

PiperOrigin-RevId: 267091566
Parent: a85c40e3

Showing 5 changed files with 25 additions and 12 deletions (+25 -12)
official/bert/common_flags.py           +4  -0
official/bert/model_training_utils.py   +7  -5
official/bert/run_classifier.py         +6  -3
official/bert/run_pretraining.py        +5  -3
official/bert/run_squad.py              +3  -1
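The new scale_loss flag controls how a per-replica loss is combined under tf.distribute: when gradients from all replicas are summed, dividing each replica's local mean loss by num_replicas_in_sync makes the summed result match the mean over the global batch. A minimal sketch of that arithmetic, not part of the commit and using the repo's 2-space indent style:

import tensorflow as tf

# Minimal sketch, not from this commit: scale the local mean loss so that the
# cross-replica SUM of gradients matches the gradient of the global-batch mean.
def per_replica_loss(per_example_losses, num_replicas, scale_loss):
  loss = tf.reduce_mean(per_example_losses)  # mean over the local batch only
  if scale_loss:
    loss = loss / num_replicas  # summing across replicas then recovers the global mean
  return loss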
official/bert/common_flags.py

@@ -52,6 +52,10 @@ def define_common_bert_flags():
   flags.DEFINE_boolean(
       'run_eagerly', False,
       'Run the model op by op without building a model function.')
+  flags.DEFINE_boolean(
+      'scale_loss', False,
+      'Whether to divide the loss by number of replica inside the per-replica '
+      'loss function.')
 
   # Adds flags for mixed precision training.
   flags_core.define_performance(
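For reference, the added flag is a standard absl boolean flag. A self-contained sketch of how such a flag is defined and consumed; the wrapping script is hypothetical, only the flag name and help text come from the diff:

from absl import app, flags

flags.DEFINE_boolean(
    'scale_loss', False,
    'Whether to divide the loss by number of replica inside the per-replica '
    'loss function.')
FLAGS = flags.FLAGS


def main(_):
  # After app.run parses argv, the flag is a plain bool, e.g. passing
  # --scale_loss on the command line sets it to True.
  print('scale_loss =', FLAGS.scale_loss)


if __name__ == '__main__':
  app.run(main)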
official/bert/model_training_utils.py

@@ -231,6 +231,10 @@ def run_customized_training_loop(
   else:
     train_summary_writer = None
 
+  # De-dupes variables due to keras tracking issues.
+  training_vars = list(
+      {id(v): v for v in model.trainable_variables}.values())
+
   def _replicated_step(inputs):
     """Replicated training step."""

@@ -241,14 +245,12 @@ def run_customized_training_loop(
       if use_float16:
         scaled_loss = optimizer.get_scaled_loss(loss)
 
-    # De-dupes variables due to keras tracking issues.
-    tvars = list({id(v): v for v in model.trainable_variables}.values())
     if use_float16:
-      scaled_grads = tape.gradient(scaled_loss, tvars)
+      scaled_grads = tape.gradient(scaled_loss, training_vars)
       grads = optimizer.get_unscaled_gradients(scaled_grads)
     else:
-      grads = tape.gradient(loss, tvars)
-    optimizer.apply_gradients(zip(grads, tvars))
+      grads = tape.gradient(loss, training_vars)
+    optimizer.apply_gradients(zip(grads, training_vars))
     # For reporting, the metric takes the mean of losses.
     train_loss_metric.update_state(loss)
     for metric in train_metrics:
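This change hoists the de-duplication of model.trainable_variables out of _replicated_step, so the list is built once per training run instead of on every step. The {id(v): v ...} idiom itself is plain Python; an illustrative sketch below, where the Var class and the tracked list are made up for the example:

class Var:
  """Stand-in for a tf.Variable; only object identity matters here."""


a, b = Var(), Var()
tracked = [a, b, a]  # the same object tracked twice, as Keras sometimes does

# Keep one entry per underlying object. Dicts preserve insertion order, so the
# order of first occurrences is kept.
unique = list({id(v): v for v in tracked}.values())
assert unique == [a, b]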
official/bert/run_classifier.py

@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS
 
 
-def get_loss_fn(num_classes, loss_scale=1.0):
+def get_loss_fn(num_classes, loss_factor=1.0):
   """Gets the classification loss function."""
 
   def classification_loss_fn(labels, logits):

@@ -73,7 +73,7 @@ def get_loss_fn(num_classes, loss_scale=1.0):
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
     loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_scale
+    loss *= loss_factor
     return loss
 
   return classification_loss_fn

@@ -118,7 +118,10 @@ def run_customized_training(strategy,
         initial_lr, steps_per_epoch * epochs, warmup_steps)
     return classifier_model, core_model
 
-  loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
+  loss_fn = get_loss_fn(
+      num_classes,
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
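Note that the expression loss_factor=1.0 / strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0 relies on Python's conditional expression binding more loosely than division, so it parses as (1.0 / n) if flag else 1.0. A short worked example with an assumed replica count of 8:

# Hypothetical values: 8 replicas, flag enabled.
num_replicas_in_sync = 8
scale_loss = True

loss_factor = (1.0 / num_replicas_in_sync) if scale_loss else 1.0
print(loss_factor)  # 0.125; with scale_loss=False it stays 1.0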
official/bert/run_pretraining.py

@@ -94,11 +94,11 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
   return _dataset_fn if use_dataset_fn else _dataset_fn()
 
 
-def get_loss_fn(loss_scale=1.0):
+def get_loss_fn(loss_factor=1.0):
   """Returns loss function for BERT pretraining."""
 
   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.keras.backend.mean(losses) * loss_scale
+    return tf.keras.backend.mean(losses) * loss_factor
 
   return _bert_pretrain_loss_fn

@@ -132,7 +132,9 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(),
+      loss_fn=get_loss_fn(
+          loss_factor=1.0 /
+          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
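One way to see why 1/num_replicas_in_sync is the right factor: under a synchronous strategy, summing the scaled per-replica losses recovers the unscaled mean. A rough sanity-check sketch, not part of the commit and assuming a recent TF 2.x MirroredStrategy API:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per visible device
n = strategy.num_replicas_in_sync


def replica_fn():
  # Pretend each replica computed a local mean loss of 2.0, then applied the
  # 1/n factor that --scale_loss requests.
  return tf.constant(2.0) / n


per_replica = strategy.run(replica_fn)
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(float(total))  # ~2.0, i.e. the global-batch mean is recovered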
official/bert/run_squad.py

@@ -232,7 +232,9 @@ def train_squad(strategy,
   # 1/num_replicas_in_sync. It could be an accident. So, in order to use
   # the same hyper parameter, we do the same thing here by keeping each
   # replica loss as it is.
-  loss_fn = get_loss_fn(loss_factor=1.0)
+  loss_fn = get_loss_fn(
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
   use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   model_training_utils.run_customized_training_loop(