Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b39958cf
Commit
b39958cf
authored
Mar 11, 2020
by
Will Cromar
Committed by
A. Unique TensorFlower
Mar 11, 2020
Browse files
Add TimeHistory callback to BERT.
PiperOrigin-RevId: 300433601
parent
1792fb76
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
39 additions
and
6 deletions
+39
-6
official/benchmark/bert_benchmark_utils.py
official/benchmark/bert_benchmark_utils.py
+5
-0
official/modeling/model_training_utils.py
official/modeling/model_training_utils.py
+1
-1
official/nlp/bert/common_flags.py
official/nlp/bert/common_flags.py
+2
-0
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+12
-2
official/nlp/bert/run_squad.py
official/nlp/bert/run_squad.py
+16
-1
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+3
-2
No files found.
official/benchmark/bert_benchmark_utils.py
View file @
b39958cf
...
...
@@ -44,6 +44,11 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
self
.
batch_start_times
[
batch
]
=
time
.
time
()
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
# If there are multiple steps_per_loop, the end batch index will not be the
# same as the starting index. Use the last starting index instead.
if
batch
not
in
self
.
batch_start_times
:
batch
=
max
(
self
.
batch_start_times
.
keys
())
self
.
batch_stop_times
[
batch
]
=
time
.
time
()
def
get_examples_per_sec
(
self
,
batch_size
,
num_batches_to_skip
=
1
):
...
...
official/modeling/model_training_utils.py
View file @
b39958cf
...
...
@@ -419,8 +419,8 @@ def run_customized_training_loop(
train_steps
(
train_iterator
,
tf
.
convert_to_tensor
(
steps
,
dtype
=
tf
.
int32
))
train_loss
=
_float_metric_value
(
train_loss_metric
)
_run_callbacks_on_batch_end
(
current_step
,
{
'loss'
:
train_loss
})
current_step
+=
steps
_run_callbacks_on_batch_end
(
current_step
-
1
,
{
'loss'
:
train_loss
})
# Updates training logging.
training_status
=
'Train Step: %d/%d / loss = %s'
%
(
...
...
official/nlp/bert/common_flags.py
View file @
b39958cf
...
...
@@ -77,6 +77,8 @@ def define_common_bert_flags():
flags
.
DEFINE_bool
(
'hub_module_trainable'
,
True
,
'True to make keras layers in the hub module trainable.'
)
flags_core
.
define_log_steps
()
# Adds flags for mixed precision and multi-worker training.
flags_core
.
define_performance
(
num_parallel_calls
=
False
,
...
...
official/nlp/bert/run_classifier.py
View file @
b39958cf
...
...
@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
epochs
,
steps_per_epoch
,
eval_steps
,
custom_callbacks
=
None
)
custom_callbacks
=
custom_callbacks
)
# Use user-defined loop to start training.
logging
.
info
(
'Training using customized training loop TF 2.0 with '
...
...
@@ -363,6 +363,15 @@ def run_bert(strategy,
if
not
strategy
:
raise
ValueError
(
'Distribution strategy has not been specified.'
)
if
FLAGS
.
log_steps
:
custom_callbacks
=
[
keras_utils
.
TimeHistory
(
batch_size
=
FLAGS
.
train_batch_size
,
log_steps
=
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
,
)]
else
:
custom_callbacks
=
None
trained_model
=
run_bert_classifier
(
strategy
,
model_config
,
...
...
@@ -378,7 +387,8 @@ def run_bert(strategy,
train_input_fn
,
eval_input_fn
,
run_eagerly
=
FLAGS
.
run_eagerly
,
use_keras_compile_fit
=
FLAGS
.
use_keras_compile_fit
)
use_keras_compile_fit
=
FLAGS
.
use_keras_compile_fit
,
custom_callbacks
=
custom_callbacks
)
if
FLAGS
.
model_export_path
:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
...
...
official/nlp/bert/run_squad.py
View file @
b39958cf
...
...
@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
from
official.nlp.bert
import
tokenization
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
flags
.
DEFINE_string
(
'vocab_file'
,
None
,
...
...
@@ -94,7 +95,21 @@ def main(_):
all_reduce_alg
=
FLAGS
.
all_reduce_alg
,
tpu_address
=
FLAGS
.
tpu
)
if
FLAGS
.
mode
in
(
'train'
,
'train_and_predict'
):
train_squad
(
strategy
,
input_meta_data
,
run_eagerly
=
FLAGS
.
run_eagerly
)
if
FLAGS
.
log_steps
:
custom_callbacks
=
[
keras_utils
.
TimeHistory
(
batch_size
=
FLAGS
.
train_batch_size
,
log_steps
=
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
,
)]
else
:
custom_callbacks
=
None
train_squad
(
strategy
,
input_meta_data
,
custom_callbacks
=
custom_callbacks
,
run_eagerly
=
FLAGS
.
run_eagerly
,
)
if
FLAGS
.
mode
in
(
'predict'
,
'train_and_predict'
):
predict_squad
(
strategy
,
input_meta_data
)
...
...
official/utils/misc/keras_utils.py
View file @
b39958cf
...
...
@@ -117,8 +117,9 @@ class TimeHistory(tf.keras.callbacks.Callback):
self
.
timestamp_log
.
append
(
BatchTimestamp
(
self
.
global_steps
,
now
))
logging
.
info
(
'TimeHistory: %.2f examples/second between steps %d and %d'
,
examples_per_second
,
self
.
last_log_step
,
self
.
global_steps
)
'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
'and %d'
,
elapsed_time
,
examples_per_second
,
self
.
last_log_step
,
self
.
global_steps
)
if
self
.
summary_writer
:
with
self
.
summary_writer
.
as_default
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment