Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
adb61343
Commit
adb61343
authored
May 11, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
May 11, 2020
Browse files
Internal change
PiperOrigin-RevId: 311069693
parent
49b223b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
12 deletions
+33
-12
official/benchmark/bert_pretrain_benchmark.py
official/benchmark/bert_pretrain_benchmark.py
+4
-0
official/nlp/bert/model_training_utils.py
official/nlp/bert/model_training_utils.py
+23
-12
official/nlp/bert/run_pretraining.py
official/nlp/bert/run_pretraining.py
+6
-0
No files found.
official/benchmark/bert_pretrain_benchmark.py
View file @
adb61343
...
@@ -137,6 +137,10 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
...
@@ -137,6 +137,10 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
'benchmark_accuracy_8x8_tpu_bf16_seq128_1m_steps'
)
'benchmark_accuracy_8x8_tpu_bf16_seq128_1m_steps'
)
summary_path
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
summary_path
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'summaries/training_summary.txt'
)
'summaries/training_summary.txt'
)
# Set train_summary_interval to -1 to disable training summary, because
# writing summary to gcs may fail and summaries are not needed for this
# accuracy benchmark test.
FLAGS
.
train_summary_interval
=
-
1
self
.
_run_and_report_benchmark
(
summary_path
=
summary_path
,
self
.
_run_and_report_benchmark
(
summary_path
=
summary_path
,
report_accuracy
=
True
)
report_accuracy
=
True
)
...
...
official/nlp/bert/model_training_utils.py
View file @
adb61343
...
@@ -89,6 +89,8 @@ def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
...
@@ -89,6 +89,8 @@ def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
def
write_txt_summary
(
training_summary
,
summary_dir
):
def
write_txt_summary
(
training_summary
,
summary_dir
):
"""Writes a summary text file to record stats."""
"""Writes a summary text file to record stats."""
if
not
tf
.
io
.
gfile
.
exists
(
summary_dir
):
tf
.
io
.
gfile
.
mkdir
(
summary_dir
)
summary_path
=
os
.
path
.
join
(
summary_dir
,
_SUMMARY_TXT
)
summary_path
=
os
.
path
.
join
(
summary_dir
,
_SUMMARY_TXT
)
with
tf
.
io
.
gfile
.
GFile
(
summary_path
,
'wb'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
summary_path
,
'wb'
)
as
f
:
logging
.
info
(
'Training Summary:
\n
%s'
,
str
(
training_summary
))
logging
.
info
(
'Training Summary:
\n
%s'
,
str
(
training_summary
))
...
@@ -117,7 +119,8 @@ def run_customized_training_loop(
...
@@ -117,7 +119,8 @@ def run_customized_training_loop(
sub_model_export_name
=
None
,
sub_model_export_name
=
None
,
explicit_allreduce
=
False
,
explicit_allreduce
=
False
,
pre_allreduce_callbacks
=
None
,
pre_allreduce_callbacks
=
None
,
post_allreduce_callbacks
=
None
):
post_allreduce_callbacks
=
None
,
train_summary_interval
=
0
):
"""Run BERT pretrain model training using low-level API.
"""Run BERT pretrain model training using low-level API.
Arguments:
Arguments:
...
@@ -181,6 +184,8 @@ def run_customized_training_loop(
...
@@ -181,6 +184,8 @@ def run_customized_training_loop(
functions will be invoked in the list order and right before gradients
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks. Only used
are applied to variables for updates. Default is no callbacks. Only used
when explicit_allreduce=True.
when explicit_allreduce=True.
train_summary_interval: Step interval for training summaries. If the value
is a negative number, then training summaries are not enabled.
Returns:
Returns:
Trained model.
Trained model.
...
@@ -272,13 +277,14 @@ def run_customized_training_loop(
...
@@ -272,13 +277,14 @@ def run_customized_training_loop(
summary_dir
=
tempfile
.
mkdtemp
()
summary_dir
=
tempfile
.
mkdtemp
()
eval_summary_writer
=
tf
.
summary
.
create_file_writer
(
eval_summary_writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
summary_dir
,
'eval'
))
os
.
path
.
join
(
summary_dir
,
'eval'
))
if
steps_per_loop
>=
_MIN_SUMMARY_STEPS
:
last_summary_step
=
0
if
steps_per_loop
>=
_MIN_SUMMARY_STEPS
and
train_summary_interval
>=
0
:
# Only writes summary when the stats are collected sufficiently over
# Only writes summary when the stats are collected sufficiently over
# enough steps.
# enough steps.
train_summary_writer
=
tf
.
summary
.
create_file_writer
(
train_summary_writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
summary_dir
,
'train'
))
os
.
path
.
join
(
summary_dir
,
'train'
))
else
:
else
:
train_summary_writer
=
None
train_summary_writer
=
tf
.
summary
.
create_noop_writer
()
# Collects training variables.
# Collects training variables.
training_vars
=
model
.
trainable_variables
training_vars
=
model
.
trainable_variables
...
@@ -438,15 +444,20 @@ def run_customized_training_loop(
...
@@ -438,15 +444,20 @@ def run_customized_training_loop(
training_status
=
'Train Step: %d/%d / loss = %s'
%
(
training_status
=
'Train Step: %d/%d / loss = %s'
%
(
current_step
,
total_training_steps
,
train_loss
)
current_step
,
total_training_steps
,
train_loss
)
if
train_summary_writer
:
if
current_step
>=
last_summary_step
+
train_summary_interval
:
with
train_summary_writer
.
as_default
():
summary_writer
=
train_summary_writer
tf
.
summary
.
scalar
(
last_summary_step
=
current_step
train_loss_metric
.
name
,
train_loss
,
step
=
current_step
)
else
:
for
metric
in
train_metrics
+
model
.
metrics
:
summary_writer
=
tf
.
summary
.
create_noop_writer
()
metric_value
=
_float_metric_value
(
metric
)
training_status
+=
' %s = %f'
%
(
metric
.
name
,
metric_value
)
with
summary_writer
.
as_default
():
tf
.
summary
.
scalar
(
metric
.
name
,
metric_value
,
step
=
current_step
)
tf
.
summary
.
scalar
(
train_summary_writer
.
flush
()
train_loss_metric
.
name
,
train_loss
,
step
=
current_step
)
for
metric
in
train_metrics
+
model
.
metrics
:
metric_value
=
_float_metric_value
(
metric
)
training_status
+=
' %s = %f'
%
(
metric
.
name
,
metric_value
)
tf
.
summary
.
scalar
(
metric
.
name
,
metric_value
,
step
=
current_step
)
summary_writer
.
flush
()
logging
.
info
(
training_status
)
logging
.
info
(
training_status
)
if
current_step
%
steps_per_epoch
==
0
:
if
current_step
%
steps_per_epoch
==
0
:
...
...
official/nlp/bert/run_pretraining.py
View file @
adb61343
...
@@ -49,6 +49,9 @@ flags.DEFINE_float('warmup_steps', 10000,
...
@@ -49,6 +49,9 @@ flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.'
)
'Warmup steps for Adam weight decay optimizer.'
)
flags
.
DEFINE_bool
(
'use_next_sentence_label'
,
True
,
flags
.
DEFINE_bool
(
'use_next_sentence_label'
,
True
,
'Whether to use next sentence label to compute final loss.'
)
'Whether to use next sentence label to compute final loss.'
)
# The value is a step interval (0 by default, negative to disable training
# summaries), so it must be declared as an integer flag. A boolean flag would
# reject command-line values such as `--train_summary_interval=100`.
flags.DEFINE_integer(
    'train_summary_interval', 0, 'Step interval for training '
    'summaries. If the value is a negative number, '
    'then training summaries are not enabled.')
common_flags
.
define_common_bert_flags
()
common_flags
.
define_common_bert_flags
()
...
@@ -101,6 +104,7 @@ def run_customized_training(strategy,
...
@@ -101,6 +104,7 @@ def run_customized_training(strategy,
input_files
,
input_files
,
train_batch_size
,
train_batch_size
,
use_next_sentence_label
=
True
,
use_next_sentence_label
=
True
,
train_summary_interval
=
0
,
custom_callbacks
=
None
):
custom_callbacks
=
None
):
"""Run BERT pretrain model training using low-level API."""
"""Run BERT pretrain model training using low-level API."""
...
@@ -135,6 +139,7 @@ def run_customized_training(strategy,
...
@@ -135,6 +139,7 @@ def run_customized_training(strategy,
steps_per_loop
=
steps_per_loop
,
steps_per_loop
=
steps_per_loop
,
epochs
=
epochs
,
epochs
=
epochs
,
sub_model_export_name
=
'pretrained/bert_model'
,
sub_model_export_name
=
'pretrained/bert_model'
,
train_summary_interval
=
train_summary_interval
,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
)
return
trained_model
return
trained_model
...
@@ -170,6 +175,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
...
@@ -170,6 +175,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
FLAGS
.
input_files
,
FLAGS
.
input_files
,
FLAGS
.
train_batch_size
,
FLAGS
.
train_batch_size
,
FLAGS
.
use_next_sentence_label
,
FLAGS
.
use_next_sentence_label
,
FLAGS
.
train_summary_interval
,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment