Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
533d1e6b
Commit
533d1e6b
authored
Mar 02, 2020
by
Will Cromar
Committed by
A. Unique TensorFlower
Mar 02, 2020
Browse files
Add TimeHistory callback to BERT.
PiperOrigin-RevId: 298466825
parent
7152763a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
45 additions
and
14 deletions
+45
-14
official/modeling/model_training_utils.py
official/modeling/model_training_utils.py
+1
-1
official/nlp/bert/common_flags.py
official/nlp/bert/common_flags.py
+2
-0
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+12
-2
official/nlp/bert/run_squad.py
official/nlp/bert/run_squad.py
+16
-1
official/utils/flags/_benchmark.py
official/utils/flags/_benchmark.py
+9
-5
official/utils/flags/core.py
official/utils/flags/core.py
+1
-0
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+4
-5
No files found.
official/modeling/model_training_utils.py
View file @
533d1e6b
...
...
@@ -368,8 +368,8 @@ def run_customized_training_loop(
       train_steps(train_iterator,
                   tf.convert_to_tensor(steps, dtype=tf.int32))
       train_loss = _float_metric_value(train_loss_metric)
-      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
       current_step += steps
+      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})

       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
...
...
official/nlp/bert/common_flags.py
View file @
533d1e6b
...
...
@@ -69,6 +69,8 @@ def define_common_bert_flags():
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')

+  flags_core.define_log_steps()
+
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
       num_parallel_calls=False,
...
...
official/nlp/bert/run_classifier.py
View file @
533d1e6b
...
...
@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
         epochs,
         steps_per_epoch,
         eval_steps,
-        custom_callbacks=None)
+        custom_callbacks=custom_callbacks)

     # Use user-defined loop to start training.
     logging.info('Training using customized training loop TF 2.0 with '
...
...
@@ -311,6 +311,15 @@ def run_bert(strategy,
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')

+  if FLAGS.log_steps:
+    custom_callbacks = [keras_utils.TimeHistory(
+        batch_size=FLAGS.train_batch_size,
+        log_steps=FLAGS.log_steps,
+        logdir=FLAGS.model_dir,
+    )]
+  else:
+    custom_callbacks = None
+
   trained_model = run_bert_classifier(
       strategy,
       model_config,
...
...
@@ -326,7 +335,8 @@ def run_bert(strategy,
       train_input_fn,
       eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
-      use_keras_compile_fit=FLAGS.use_keras_compile_fit)
+      use_keras_compile_fit=FLAGS.use_keras_compile_fit,
+      custom_callbacks=custom_callbacks)

   if FLAGS.model_export_path:
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
...
...
official/nlp/bert/run_squad.py
View file @
533d1e6b
...
...
@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib as squad_lib_wp
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 flags.DEFINE_string('vocab_file', None,
...
...
@@ -94,7 +95,21 @@ def main(_):
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
   if FLAGS.mode in ('train', 'train_and_predict'):
-    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
+    if FLAGS.log_steps:
+      custom_callbacks = [keras_utils.TimeHistory(
+          batch_size=FLAGS.train_batch_size,
+          log_steps=FLAGS.log_steps,
+          logdir=FLAGS.model_dir,
+      )]
+    else:
+      custom_callbacks = None
+
+    train_squad(
+        strategy,
+        input_meta_data,
+        custom_callbacks=custom_callbacks,
+        run_eagerly=FLAGS.run_eagerly,
+    )
   if FLAGS.mode in ('predict', 'train_and_predict'):
     predict_squad(strategy, input_meta_data)
...
...
official/utils/flags/_benchmark.py
View file @
533d1e6b
...
...
@@ -23,6 +23,14 @@ from absl import flags
 from official.utils.flags._conventions import help_wrap
def define_log_steps():
  """Register the `log_steps` flag used by the TimeHistory callback.

  The flag controls how often (in training steps) timing information
  such as examples/second is logged.

  Returns:
    An empty list: this function registers no key flags, matching the
    convention of the other `define_*` helpers in this module, whose
    return values are consumed by `register_key_flags_in_core`.
  """
  flags.DEFINE_integer(
      name="log_steps",
      default=100,
      help="Frequency with which to log timing information with TimeHistory.")

  return []
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
  """Register benchmarking flags.
...
...
@@ -52,11 +60,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
           "human consumption, and does not have any impact within "
           "the system."))

-  flags.DEFINE_integer(
-      name='log_steps', default=100,
-      help='For every log_steps, we log the timing information such as '
-           'examples per second. Besides, for every log_steps, we store the '
-           'timestamp of a batch end.')
+  define_log_steps()

   if benchmark_log_dir:
     flags.DEFINE_string(
...
...
official/utils/flags/core.py
View file @
533d1e6b
...
...
@@ -72,6 +72,7 @@ define_base = register_key_flags_in_core(_base.define_base)
 # We have define_base_eager for compatibility, since it used to be a separate
 # function from define_base.
 define_base_eager = define_base
+define_log_steps = register_key_flags_in_core(_benchmark.define_log_steps)
 define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
 define_device = register_key_flags_in_core(_device.define_device)
 define_image = register_key_flags_in_core(_misc.define_image)
...
...
official/utils/misc/keras_utils.py
View file @
533d1e6b
...
...
@@ -23,8 +23,7 @@ import os
 import time

 from absl import logging
-import tensorflow as tf
-from tensorflow.core.protobuf import rewriter_config_pb2
+import tensorflow.compat.v2 as tf
 from tensorflow.python import tf2
 from tensorflow.python.profiler import profiler_v2 as profiler
...
...
@@ -118,7 +117,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
       self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
       logging.info(
-          "TimeHistory: %.2f examples/second between steps %d and %d",
+          'TimeHistory: %.2f examples/second between steps %d and %d',
           examples_per_second, self.last_log_step, self.global_steps)

       if self.summary_writer:
...
...
@@ -209,8 +208,8 @@ def set_session_config(enable_eager=False,
   if enable_eager:
     tf.compat.v1.enable_eager_execution(config=config)
   else:
-    sess = tf.Session(config=config)
-    tf.keras.backend.set_session(sess)
+    sess = tf.compat.v1.Session(config=config)
+    tf.compat.v1.keras.backend.set_session(sess)


 def get_config_proto_v1(enable_xla=False):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment