ModelZoo / ResNet50_tensorflow

Commit 7207422d
authored Oct 08, 2019 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 273653001

parent dc93d9e5
Showing 2 changed files with 87 additions and 17 deletions (+87, -17)
official/nlp/bert/common_flags.py    +4  -0
official/nlp/bert/run_classifier.py  +83 -17
official/nlp/bert/common_flags.py

@@ -56,6 +56,10 @@ def define_common_bert_flags():
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
       'loss function.')
+  flags.DEFINE_boolean(
+      'use_keras_compile_fit', False,
+      'If True, uses Keras compile/fit() API for training logic. Otherwise '
+      'use custom training loop.')
 
   # Adds flags for mixed precision training.
   flags_core.define_performance(
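For readers unfamiliar with absl flags, the sketch below (illustrative only, not part of this commit; the script name and print statements are placeholders) shows how a boolean flag defined this way is parsed from the command line and read back through FLAGS:

# Illustrative only -- not part of this commit. Shows how the new
# 'use_keras_compile_fit' boolean is defined and then read via FLAGS
# after absl parses argv.
from absl import app
from absl import flags

flags.DEFINE_boolean(
    'use_keras_compile_fit', False,
    'If True, uses Keras compile/fit() API for training logic. Otherwise '
    'use custom training loop.')

FLAGS = flags.FLAGS


def main(_):
  if FLAGS.use_keras_compile_fit:
    print('Would train with the Keras compile/fit() path.')
  else:
    print('Would train with the custom training loop.')


if __name__ == '__main__':
  app.run(main)  # e.g. python flag_demo.py --use_keras_compile_fit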
official/nlp/bert/run_classifier.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import functools
 import json
 import math
+import os
 
 from absl import app
 from absl import flags
@@ -82,19 +83,19 @@ def get_loss_fn(num_classes, loss_factor=1.0):
   return classification_loss_fn
 
 
-def run_customized_training(strategy,
-                            bert_config,
-                            input_meta_data,
-                            model_dir,
-                            epochs,
-                            steps_per_epoch,
-                            steps_per_loop,
-                            eval_steps,
-                            warmup_steps,
-                            initial_lr,
-                            init_checkpoint,
-                            custom_callbacks=None,
-                            run_eagerly=False):
+def run_bert_classifier(strategy,
+                        bert_config,
+                        input_meta_data,
+                        model_dir,
+                        epochs,
+                        steps_per_epoch,
+                        steps_per_loop,
+                        eval_steps,
+                        warmup_steps,
+                        initial_lr,
+                        init_checkpoint,
+                        custom_callbacks=None,
+                        run_eagerly=False):
   """Run BERT classifier training using low-level API."""
   max_seq_length = input_meta_data['max_seq_length']
   num_classes = input_meta_data['num_labels']
@@ -144,6 +145,27 @@ def run_customized_training(strategy,
     return tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)
 
+  if FLAGS.use_keras_compile_fit:
+    # Start training using Keras compile/fit API.
+    logging.info('Training using TF 2.0 Keras compile/fit API with '
+                 'distrubuted strategy.')
+    return run_keras_compile_fit(
+        model_dir,
+        strategy,
+        _get_classifier_model,
+        train_input_fn,
+        eval_input_fn,
+        loss_fn,
+        metric_fn,
+        init_checkpoint,
+        epochs,
+        steps_per_epoch,
+        eval_steps,
+        custom_callbacks=None)
+
+  # Use user-defined loop to start training.
+  logging.info('Training using customized training loop TF 2.0 with '
+               'distrubuted strategy.')
   return model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_classifier_model,
@@ -161,6 +183,52 @@ def run_customized_training(strategy,
       run_eagerly=run_eagerly)
 
 
+def run_keras_compile_fit(model_dir,
+                          strategy,
+                          model_fn,
+                          train_input_fn,
+                          eval_input_fn,
+                          loss_fn,
+                          metric_fn,
+                          init_checkpoint,
+                          epochs,
+                          steps_per_epoch,
+                          eval_steps,
+                          custom_callbacks=None):
+  """Runs BERT classifier model using Keras compile/fit API."""
+
+  with strategy.scope():
+    training_dataset = train_input_fn()
+    evaluation_dataset = eval_input_fn()
+    bert_model, sub_model = model_fn()
+    optimizer = bert_model.optimizer
+
+    if init_checkpoint:
+      checkpoint = tf.train.Checkpoint(model=sub_model)
+      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
+
+    bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
+
+    summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
+    checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
+    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
+
+    if custom_callbacks is not None:
+      custom_callbacks += [summary_callback, checkpoint_callback]
+    else:
+      custom_callbacks = [summary_callback, checkpoint_callback]
+
+    bert_model.fit(
+        x=training_dataset,
+        validation_data=evaluation_dataset,
+        steps_per_epoch=steps_per_epoch,
+        epochs=epochs,
+        validation_steps=eval_steps,
+        callbacks=custom_callbacks)
+
+  return bert_model
+
+
 def export_classifier(model_export_path, input_meta_data):
   """Exports a trained model as a `SavedModel` for inference.
@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')
-  # Runs customized training loop.
-  logging.info('Training using customized training loop TF 2.0 with distrubuted'
-               'strategy.')
-  trained_model = run_customized_training(
+
+  trained_model = run_bert_classifier(
       strategy,
       bert_config,
       input_meta_data,
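Taken together, the new code path is the standard TF 2.x recipe: build and compile the model inside a distribution strategy scope, attach TensorBoard and ModelCheckpoint callbacks, and call fit(). Below is a minimal self-contained sketch of that recipe with a toy model, synthetic data, and placeholder paths (illustrative only, not code from this commit):

# Minimal sketch of the compile/fit-under-strategy pattern used by the new
# run_keras_compile_fit path. Toy model and synthetic data only.
import os
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
model_dir = '/tmp/compile_fit_demo'  # placeholder path

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
      tf.keras.layers.Dense(2)
  ])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

# Synthetic dataset standing in for train_input_fn() / eval_input_fn().
features = tf.random.normal([256, 8])
labels = tf.random.uniform([256], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

callbacks = [
    tf.keras.callbacks.TensorBoard(model_dir),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')),
]

model.fit(dataset, validation_data=dataset, epochs=2, callbacks=callbacks)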