Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7207422d
Commit
7207422d
authored
Oct 08, 2019
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 273653001
parent
dc93d9e5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
17 deletions
+87
-17
official/nlp/bert/common_flags.py
official/nlp/bert/common_flags.py
+4
-0
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+83
-17
No files found.
official/nlp/bert/common_flags.py
View file @
7207422d
...
...
@@ -56,6 +56,10 @@ def define_common_bert_flags():
'scale_loss'
,
False
,
'Whether to divide the loss by number of replica inside the per-replica '
'loss function.'
)
flags
.
DEFINE_boolean
(
'use_keras_compile_fit'
,
False
,
'If True, uses Keras compile/fit() API for training logic. Otherwise '
'use custom training loop.'
)
# Adds flags for mixed precision training.
flags_core
.
define_performance
(
...
...
official/nlp/bert/run_classifier.py
View file @
7207422d
...
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
functools
import
json
import
math
import
os
from
absl
import
app
from
absl
import
flags
...
...
@@ -82,19 +83,19 @@ def get_loss_fn(num_classes, loss_factor=1.0):
return
classification_loss_fn
def
run_
customized_training
(
strategy
,
bert_config
,
input_meta_data
,
model_dir
,
epochs
,
steps_per_epoch
,
steps_per_loop
,
eval_steps
,
warmup_steps
,
initial_lr
,
init_checkpoint
,
custom_callbacks
=
None
,
run_eagerly
=
False
):
def
run_
bert_classifier
(
strategy
,
bert_config
,
input_meta_data
,
model_dir
,
epochs
,
steps_per_epoch
,
steps_per_loop
,
eval_steps
,
warmup_steps
,
initial_lr
,
init_checkpoint
,
custom_callbacks
=
None
,
run_eagerly
=
False
):
"""Run BERT classifier training using low-level API."""
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
num_classes
=
input_meta_data
[
'num_labels'
]
...
...
@@ -144,6 +145,27 @@ def run_customized_training(strategy,
return
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'test_accuracy'
,
dtype
=
tf
.
float32
)
if
FLAGS
.
use_keras_compile_fit
:
# Start training using Keras compile/fit API.
logging
.
info
(
'Training using TF 2.0 Keras compile/fit API with '
'distrubuted strategy.'
)
return
run_keras_compile_fit
(
model_dir
,
strategy
,
_get_classifier_model
,
train_input_fn
,
eval_input_fn
,
loss_fn
,
metric_fn
,
init_checkpoint
,
epochs
,
steps_per_epoch
,
eval_steps
,
custom_callbacks
=
None
)
# Use user-defined loop to start training.
logging
.
info
(
'Training using customized training loop TF 2.0 with '
'distrubuted strategy.'
)
return
model_training_utils
.
run_customized_training_loop
(
strategy
=
strategy
,
model_fn
=
_get_classifier_model
,
...
...
@@ -161,6 +183,52 @@ def run_customized_training(strategy,
run_eagerly
=
run_eagerly
)
def
run_keras_compile_fit
(
model_dir
,
strategy
,
model_fn
,
train_input_fn
,
eval_input_fn
,
loss_fn
,
metric_fn
,
init_checkpoint
,
epochs
,
steps_per_epoch
,
eval_steps
,
custom_callbacks
=
None
):
"""Runs BERT classifier model using Keras compile/fit API."""
with
strategy
.
scope
():
training_dataset
=
train_input_fn
()
evaluation_dataset
=
eval_input_fn
()
bert_model
,
sub_model
=
model_fn
()
optimizer
=
bert_model
.
optimizer
if
init_checkpoint
:
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
sub_model
)
checkpoint
.
restore
(
init_checkpoint
).
assert_existing_objects_matched
()
bert_model
.
compile
(
optimizer
=
optimizer
,
loss
=
loss_fn
,
metrics
=
[
metric_fn
()])
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
model_dir
)
checkpoint_dir
=
os
.
path
.
join
(
model_dir
,
'model_checkpoint.{epoch:02d}'
)
checkpoint_callback
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
checkpoint_dir
)
if
custom_callbacks
is
not
None
:
custom_callbacks
+=
[
summary_callback
,
checkpoint_callback
]
else
:
custom_callbacks
=
[
summary_callback
,
checkpoint_callback
]
bert_model
.
fit
(
x
=
training_dataset
,
validation_data
=
evaluation_dataset
,
steps_per_epoch
=
steps_per_epoch
,
epochs
=
epochs
,
validation_steps
=
eval_steps
,
callbacks
=
custom_callbacks
)
return
bert_model
def
export_classifier
(
model_export_path
,
input_meta_data
):
"""Exports a trained model as a `SavedModel` for inference.
...
...
@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
if
not
strategy
:
raise
ValueError
(
'Distribution strategy has not been specified.'
)
# Runs customized training loop.
logging
.
info
(
'Training using customized training loop TF 2.0 with distrubuted'
'strategy.'
)
trained_model
=
run_customized_training
(
trained_model
=
run_bert_classifier
(
strategy
,
bert_config
,
input_meta_data
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment