ModelZoo / ResNet50_tensorflow

Commit adb61343, authored May 11, 2020 by Chen Chen and committed by A. Unique TensorFlower on May 11, 2020 (parent: 49b223b6).

Internal change

PiperOrigin-RevId: 311069693
Showing 3 changed files, with 33 additions and 12 deletions:

- official/benchmark/bert_pretrain_benchmark.py (+4, -0)
- official/nlp/bert/model_training_utils.py (+23, -12)
- official/nlp/bert/run_pretraining.py (+6, -0)
official/benchmark/bert_pretrain_benchmark.py

```diff
@@ -137,6 +137,10 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
         'benchmark_accuracy_8x8_tpu_bf16_seq128_1m_steps')
     summary_path = os.path.join(FLAGS.model_dir,
                                 'summaries/training_summary.txt')
+    # Set train_summary_interval to -1 to disable training summary, because
+    # writing summary to gcs may fail and summaries are not needed for this
+    # accuracy benchmark test.
+    FLAGS.train_summary_interval = -1
     self._run_and_report_benchmark(summary_path=summary_path,
                                    report_accuracy=True)
```
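The four added lines opt this accuracy benchmark out of the new training summaries by overriding the flag in-process before the run starts. A minimal sketch of the same override pattern, assuming the integer flag declaration added in run_pretraining.py below (the benchmark body here is a hypothetical stand-in):

```python
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'train_summary_interval', 0,
    'Step interval for training summaries; a negative value disables them.')

def run_accuracy_benchmark():
  # Stand-in for the benchmark method above: flip the flag off before the
  # measured run so summary writes can never fail against remote storage.
  FLAGS.train_summary_interval = -1
  assert FLAGS.train_summary_interval < 0

if __name__ == '__main__':
  FLAGS(['benchmark'])  # Parse flags; no command-line arguments needed here.
  run_accuracy_benchmark()
```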
official/nlp/bert/model_training_utils.py

```diff
@@ -89,6 +89,8 @@ def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
 def write_txt_summary(training_summary, summary_dir):
   """Writes a summary text file to record stats."""
+  if not tf.io.gfile.exists(summary_dir):
+    tf.io.gfile.mkdir(summary_dir)
   summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
   with tf.io.gfile.GFile(summary_path, 'wb') as f:
     logging.info('Training Summary: \n%s', str(training_summary))
```
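The two added lines make write_txt_summary robust when the summary directory does not yet exist; tf.io.gfile routes the same calls to local paths and to remote filesystems such as GCS, which is exactly the failure mode the benchmark change above sidesteps. A minimal sketch of the create-then-write pattern (the file name and serialization are illustrative; the diff elides the actual write):

```python
import os
import tensorflow as tf

def write_txt_summary(training_summary, summary_dir):
  # GFile does not create missing directories, so make the directory first.
  if not tf.io.gfile.exists(summary_dir):
    tf.io.gfile.mkdir(summary_dir)
  summary_path = os.path.join(summary_dir, 'training_summary.txt')
  # The same call works for 'gs://bucket/dir' and for local paths.
  with tf.io.gfile.GFile(summary_path, 'wb') as f:
    f.write(str(training_summary))

write_txt_summary({'total_training_steps': 1000}, '/tmp/summaries')
```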
```diff
@@ -117,7 +119,8 @@ def run_customized_training_loop(
     sub_model_export_name=None,
     explicit_allreduce=False,
     pre_allreduce_callbacks=None,
-    post_allreduce_callbacks=None):
+    post_allreduce_callbacks=None,
+    train_summary_interval=0):
   """Run BERT pretrain model training using low-level API.

   Arguments:
```
```diff
@@ -181,6 +184,8 @@ def run_customized_training_loop(
         functions will be invoked in the list order and right before gradients
         are applied to variables for updates. Default is no callbacks. Only used
         when explicit_allreduce=True.
+      train_summary_interval: Step interval for training summaries. If the value
+        is a negative number, then training summaries are not enabled.

   Returns:
       Trained model.
```
```diff
@@ -272,13 +277,14 @@ def run_customized_training_loop(
     summary_dir = tempfile.mkdtemp()
   eval_summary_writer = tf.summary.create_file_writer(
       os.path.join(summary_dir, 'eval'))
-  if steps_per_loop >= _MIN_SUMMARY_STEPS:
+  last_summary_step = 0
+  if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
     # Only writes summary when the stats are collected sufficiently over
     # enough steps.
     train_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'train'))
   else:
-    train_summary_writer = None
+    train_summary_writer = tf.summary.create_noop_writer()

   # Collects training variables.
   training_vars = model.trainable_variables
```
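Swapping the None sentinel for tf.summary.create_noop_writer() means downstream code no longer needs an `if train_summary_writer:` guard: the no-op writer accepts the same as_default() and flush() calls and silently discards everything. A small self-contained illustration (the log directory is a placeholder):

```python
import tensorflow as tf

def make_train_summary_writer(logdir, enabled):
  # Return a real writer when summaries are wanted, otherwise a writer
  # with the same interface that records nothing.
  if enabled:
    return tf.summary.create_file_writer(logdir)
  return tf.summary.create_noop_writer()

writer = make_train_summary_writer('/tmp/train', enabled=False)
with writer.as_default():
  tf.summary.scalar('loss', 0.25, step=1)  # Dropped by the no-op writer.
writer.flush()  # Also a harmless no-op here.
```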
```diff
@@ -438,15 +444,20 @@ def run_customized_training_loop(
         training_status = 'Train Step: %d/%d / loss = %s' % (
             current_step, total_training_steps, train_loss)

-        if train_summary_writer:
-          with train_summary_writer.as_default():
-            tf.summary.scalar(
-                train_loss_metric.name, train_loss, step=current_step)
-            for metric in train_metrics + model.metrics:
-              metric_value = _float_metric_value(metric)
-              training_status += ' %s = %f' % (metric.name, metric_value)
-              tf.summary.scalar(metric.name, metric_value, step=current_step)
-            train_summary_writer.flush()
+        if current_step >= last_summary_step + train_summary_interval:
+          summary_writer = train_summary_writer
+          last_summary_step = current_step
+        else:
+          summary_writer = tf.summary.create_noop_writer()
+        with summary_writer.as_default():
+          tf.summary.scalar(
+              train_loss_metric.name, train_loss, step=current_step)
+          for metric in train_metrics + model.metrics:
+            metric_value = _float_metric_value(metric)
+            training_status += ' %s = %f' % (metric.name, metric_value)
+            tf.summary.scalar(metric.name, metric_value, step=current_step)
+          summary_writer.flush()
         logging.info(training_status)

         if current_step % steps_per_epoch == 0:
```
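This hunk is the heart of the change: a real summary is written only when at least train_summary_interval steps have passed since the last one, and every other iteration goes through a no-op writer. With the default interval of 0 the condition is always true, so existing callers keep per-step summaries; a negative interval never creates a file writer at all (see the previous hunk). A stripped-down sketch of the same throttling, with a fabricated loss and a placeholder log directory:

```python
import tensorflow as tf

train_summary_interval = 100
train_summary_writer = tf.summary.create_file_writer('/tmp/train')

last_summary_step = 0
for current_step in range(1, 1001):
  loss = 1.0 / current_step  # Stand-in for the real training loss.
  # Choose the real writer once per interval, the no-op writer otherwise.
  if current_step >= last_summary_step + train_summary_interval:
    summary_writer = train_summary_writer
    last_summary_step = current_step
  else:
    summary_writer = tf.summary.create_noop_writer()
  with summary_writer.as_default():
    tf.summary.scalar('train_loss', loss, step=current_step)
  summary_writer.flush()
```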
official/nlp/bert/run_pretraining.py

```diff
@@ -49,6 +49,9 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
 flags.DEFINE_bool('use_next_sentence_label', True,
                   'Whether to use next sentence label to compute final loss.')
+flags.DEFINE_bool('train_summary_interval', 0, 'Step interval for training '
+                  'summaries. If the value is a negative number, '
+                  'then training summaries are not enabled.')

 common_flags.define_common_bert_flags()
```
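One thing worth flagging: the new flag is declared with flags.DEFINE_bool even though its default is the integer 0 and the help text describes a step interval; flags.DEFINE_integer appears to be the intended declaration, matching how the value is compared against step counts in model_training_utils.py. A corrected sketch:

```python
from absl import flags

# Assumed correction: an integer flag, since the value is used as a step
# interval and compared with >= in the training loop.
flags.DEFINE_integer(
    'train_summary_interval', 0,
    'Step interval for training summaries. If the value is a negative '
    'number, then training summaries are not enabled.')
```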
```diff
@@ -101,6 +104,7 @@ def run_customized_training(strategy,
                             input_files,
                             train_batch_size,
                             use_next_sentence_label=True,
+                            train_summary_interval=0,
                             custom_callbacks=None):
   """Run BERT pretrain model training using low-level API."""
```
```diff
@@ -135,6 +139,7 @@ def run_customized_training(strategy,
       steps_per_loop=steps_per_loop,
       epochs=epochs,
       sub_model_export_name='pretrained/bert_model',
+      train_summary_interval=train_summary_interval,
       custom_callbacks=custom_callbacks)
   return trained_model
```
```diff
@@ -170,6 +175,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       FLAGS.input_files,
       FLAGS.train_batch_size,
       FLAGS.use_next_sentence_label,
+      FLAGS.train_summary_interval,
       custom_callbacks=custom_callbacks)
```
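Taken together, the plumbing above carries the value from the command line through run_bert_pretrain and run_customized_training into the training loop, so (assuming the integer flag declaration noted earlier) a run such as `python3 run_pretraining.py --train_summary_interval=1000` alongside the usual input and batch-size flags writes training summaries once every 1000 steps, while `--train_summary_interval=-1` disables them entirely.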