Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ca2e6ae0
Commit
ca2e6ae0
authored
May 23, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
May 23, 2020
Browse files
Internal change
PiperOrigin-RevId: 312923051
parent
2222cefc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
71 deletions
+97
-71
official/nlp/bert/bert_models.py
official/nlp/bert/bert_models.py
+1
-1
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+96
-70
No files found.
official/nlp/bert/bert_models.py
View file @
ca2e6ae0
...
@@ -297,7 +297,7 @@ def squad_model(bert_config,
...
@@ -297,7 +297,7 @@ def squad_model(bert_config,
def
classifier_model
(
bert_config
,
def
classifier_model
(
bert_config
,
num_labels
,
num_labels
,
max_seq_length
,
max_seq_length
=
None
,
final_layer_initializer
=
None
,
final_layer_initializer
=
None
,
hub_module_url
=
None
,
hub_module_url
=
None
,
hub_module_trainable
=
True
):
hub_module_trainable
=
True
):
...
...
official/nlp/bert/run_classifier.py
View file @
ca2e6ae0
...
@@ -37,22 +37,23 @@ from official.nlp.bert import model_training_utils
...
@@ -37,22 +37,23 @@ from official.nlp.bert import model_training_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
flags
.
DEFINE_enum
(
flags
.
DEFINE_enum
(
'mode'
,
'train_and_eval'
,
[
'train_and_eval'
,
'export_only'
],
'mode'
,
'train_and_eval'
,
[
'train_and_eval'
,
'export_only'
,
'predict'
],
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
'One of {"train_and_eval", "export_only"
, "predict"
}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.'
)
'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
'restores the model to output predictions on the test set.'
)
flags
.
DEFINE_string
(
'train_data_path'
,
None
,
flags
.
DEFINE_string
(
'train_data_path'
,
None
,
'Path to training data for BERT classifier.'
)
'Path to training data for BERT classifier.'
)
flags
.
DEFINE_string
(
'eval_data_path'
,
None
,
flags
.
DEFINE_string
(
'eval_data_path'
,
None
,
'Path to evaluation data for BERT classifier.'
)
'Path to evaluation data for BERT classifier.'
)
# Model training specific flags.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'input_meta_data_path'
,
None
,
'input_meta_data_path'
,
None
,
'Path to file that contains meta data about input '
'Path to file that contains meta data about input '
'to be used for training and evaluation.'
)
'to be used for training and evaluation.'
)
flags
.
DEFINE_string
(
'predict_checkpoint_path'
,
None
,
'Path to the checkpoint for predictions.'
)
flags
.
DEFINE_integer
(
'train_batch_size'
,
32
,
'Batch size for training.'
)
flags
.
DEFINE_integer
(
'train_batch_size'
,
32
,
'Batch size for training.'
)
flags
.
DEFINE_integer
(
'eval_batch_size'
,
32
,
'Batch size for evaluation.'
)
flags
.
DEFINE_integer
(
'eval_batch_size'
,
32
,
'Batch size for evaluation.'
)
...
@@ -125,9 +126,10 @@ def run_bert_classifier(strategy,
...
@@ -125,9 +126,10 @@ def run_bert_classifier(strategy,
max_seq_length
,
max_seq_length
,
hub_module_url
=
FLAGS
.
hub_module_url
,
hub_module_url
=
FLAGS
.
hub_module_url
,
hub_module_trainable
=
FLAGS
.
hub_module_trainable
))
hub_module_trainable
=
FLAGS
.
hub_module_trainable
))
optimizer
=
optimization
.
create_optimizer
(
optimizer
=
optimization
.
create_optimizer
(
initial_lr
,
initial_lr
,
steps_per_epoch
*
epochs
,
warmup_steps
,
steps_per_epoch
*
epochs
,
FLAGS
.
end_lr
,
FLAGS
.
optimizer_type
)
warmup_steps
,
FLAGS
.
end_lr
,
FLAGS
.
optimizer_type
)
classifier_model
.
optimizer
=
performance
.
configure_optimizer
(
classifier_model
.
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
optimizer
,
use_float16
=
common_flags
.
use_float16
(),
use_float16
=
common_flags
.
use_float16
(),
...
@@ -214,9 +216,14 @@ def run_keras_compile_fit(model_dir,
...
@@ -214,9 +216,14 @@ def run_keras_compile_fit(model_dir,
summary_dir
=
os
.
path
.
join
(
model_dir
,
'summaries'
)
summary_dir
=
os
.
path
.
join
(
model_dir
,
'summaries'
)
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
summary_dir
)
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
summary_dir
)
checkpoint_path
=
os
.
path
.
join
(
model_dir
,
'checkpoint'
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
bert_model
,
optimizer
=
optimizer
)
checkpoint_callback
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint_path
,
save_weights_only
=
True
)
checkpoint
,
directory
=
model_dir
,
max_to_keep
=
None
,
step_counter
=
optimizer
.
iterations
,
checkpoint_interval
=
0
)
checkpoint_callback
=
keras_utils
.
SimpleCheckpoint
(
checkpoint_manager
)
if
custom_callbacks
is
not
None
:
if
custom_callbacks
is
not
None
:
custom_callbacks
+=
[
summary_callback
,
checkpoint_callback
]
custom_callbacks
+=
[
summary_callback
,
checkpoint_callback
]
...
@@ -234,8 +241,10 @@ def run_keras_compile_fit(model_dir,
...
@@ -234,8 +241,10 @@ def run_keras_compile_fit(model_dir,
return
bert_model
return
bert_model
def
get_predictions_and_labels
(
strategy
,
trained_model
,
eval_input_fn
,
def
get_predictions_and_labels
(
strategy
,
eval_steps
):
trained_model
,
eval_input_fn
,
return_probs
=
False
):
"""Obtains predictions of trained model on evaluation data.
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
Note that list of labels is returned along with the predictions because the
...
@@ -245,7 +254,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
...
@@ -245,7 +254,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
strategy: Distribution strategy.
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation step
s.
return_probs: Whether to return probabilities of classe
s.
Returns:
Returns:
predictions: List of predictions.
predictions: List of predictions.
...
@@ -259,11 +268,11 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
...
@@ -259,11 +268,11 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
def
_test_step_fn
(
inputs
):
def
_test_step_fn
(
inputs
):
"""Replicated predictions."""
"""Replicated predictions."""
inputs
,
labels
=
inputs
inputs
,
labels
=
inputs
model_outputs
=
trained_model
(
inputs
,
training
=
False
)
logits
=
trained_model
(
inputs
,
training
=
False
)
return
model_outputs
,
labels
probabilities
=
tf
.
nn
.
softmax
(
logits
)
return
probabilities
,
labels
outputs
,
labels
=
strategy
.
run
(
outputs
,
labels
=
strategy
.
run
(
_test_step_fn
,
args
=
(
next
(
iterator
),))
_test_step_fn
,
args
=
(
next
(
iterator
),))
# outputs: current batch logits as a tuple of shard logits
# outputs: current batch logits as a tuple of shard logits
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
)
outputs
)
...
@@ -273,11 +282,18 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
...
@@ -273,11 +282,18 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
def
_run_evaluation
(
test_iterator
):
def
_run_evaluation
(
test_iterator
):
"""Runs evaluation steps."""
"""Runs evaluation steps."""
preds
,
golds
=
list
(),
list
()
preds
,
golds
=
list
(),
list
()
for
_
in
range
(
eval_steps
):
try
:
logits
,
labels
=
test_step
(
test_iterator
)
with
tf
.
experimental
.
async_scope
():
for
cur_logits
,
cur_labels
in
zip
(
logits
,
labels
):
while
True
:
preds
.
extend
(
tf
.
math
.
argmax
(
cur_logits
,
axis
=
1
).
numpy
())
probabilities
,
labels
=
test_step
(
test_iterator
)
golds
.
extend
(
cur_labels
.
numpy
().
tolist
())
for
cur_probs
,
cur_labels
in
zip
(
probabilities
,
labels
):
if
return_probs
:
preds
.
extend
(
cur_probs
.
numpy
().
tolist
())
else
:
preds
.
extend
(
tf
.
math
.
argmax
(
cur_probs
,
axis
=
1
).
numpy
())
golds
.
extend
(
cur_labels
.
numpy
().
tolist
())
except
(
StopIteration
,
tf
.
errors
.
OutOfRangeError
):
tf
.
experimental
.
async_clear_error
()
return
preds
,
golds
return
preds
,
golds
test_iter
=
iter
(
test_iter
=
iter
(
...
@@ -287,21 +303,13 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
...
@@ -287,21 +303,13 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
return
predictions
,
labels
return
predictions
,
labels
def
export_classifier
(
model_export_path
,
input_meta_data
,
def
export_classifier
(
model_export_path
,
input_meta_data
,
bert_config
,
restore_model_using_load_weights
,
bert_config
,
model_dir
):
model_dir
):
"""Exports a trained model as a `SavedModel` for inference.
"""Exports a trained model as a `SavedModel` for inference.
Args:
Args:
model_export_path: a string specifying the path to the SavedModel directory.
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API. There are 2
different ways to save checkpoints. One is using tf.train.Checkpoint and
another is using Keras model.save_weights(). Custom training loop
implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
callback internally uses model.save_weights() API. Since these two API's
cannot be used together, model loading logic must be take into account how
model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
summaries are stored.
...
@@ -317,14 +325,10 @@ def export_classifier(model_export_path, input_meta_data,
...
@@ -317,14 +325,10 @@ def export_classifier(model_export_path, input_meta_data,
# Export uses float32 for now, even if training uses mixed precision.
# Export uses float32 for now, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'float32'
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'float32'
)
classifier_model
=
bert_models
.
classifier_model
(
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
input_meta_data
[
'num_labels'
],
bert_config
,
input_meta_data
[
'num_labels'
])[
0
]
input_meta_data
[
'max_seq_length'
])[
0
]
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
model_export_path
,
model_export_path
,
model
=
classifier_model
,
checkpoint_dir
=
model_dir
)
model
=
classifier_model
,
checkpoint_dir
=
model_dir
,
restore_model_using_load_weights
=
restore_model_using_load_weights
)
def
run_bert
(
strategy
,
def
run_bert
(
strategy
,
...
@@ -335,17 +339,6 @@ def run_bert(strategy,
...
@@ -335,17 +339,6 @@ def run_bert(strategy,
init_checkpoint
=
None
,
init_checkpoint
=
None
,
custom_callbacks
=
None
):
custom_callbacks
=
None
):
"""Run BERT training."""
"""Run BERT training."""
if
FLAGS
.
mode
==
'export_only'
:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
export_classifier
(
FLAGS
.
model_export_path
,
input_meta_data
,
FLAGS
.
use_keras_compile_fit
,
model_config
,
FLAGS
.
model_dir
)
return
if
FLAGS
.
mode
!=
'train_and_eval'
:
raise
ValueError
(
'Unsupported mode is specified: %s'
%
FLAGS
.
mode
)
# Enables XLA in Session Config. Should not be set for TPU.
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
...
@@ -364,10 +357,11 @@ def run_bert(strategy,
...
@@ -364,10 +357,11 @@ def run_bert(strategy,
custom_callbacks
=
[]
custom_callbacks
=
[]
if
FLAGS
.
log_steps
:
if
FLAGS
.
log_steps
:
custom_callbacks
.
append
(
keras_utils
.
TimeHistory
(
custom_callbacks
.
append
(
batch_size
=
FLAGS
.
train_batch_size
,
keras_utils
.
TimeHistory
(
log_steps
=
FLAGS
.
log_steps
,
batch_size
=
FLAGS
.
train_batch_size
,
logdir
=
FLAGS
.
model_dir
))
log_steps
=
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
))
trained_model
=
run_bert_classifier
(
trained_model
=
run_bert_classifier
(
strategy
,
strategy
,
...
@@ -388,13 +382,8 @@ def run_bert(strategy,
...
@@ -388,13 +382,8 @@ def run_bert(strategy,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
)
if
FLAGS
.
model_export_path
:
if
FLAGS
.
model_export_path
:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
FLAGS
.
model_export_path
,
FLAGS
.
model_export_path
,
model
=
trained_model
)
model
=
trained_model
,
restore_model_using_load_weights
=
FLAGS
.
use_keras_compile_fit
)
return
trained_model
return
trained_model
...
@@ -412,25 +401,62 @@ def custom_main(custom_callbacks=None):
...
@@ -412,25 +401,62 @@ def custom_main(custom_callbacks=None):
if
not
FLAGS
.
model_dir
:
if
not
FLAGS
.
model_dir
:
FLAGS
.
model_dir
=
'/tmp/bert20/'
FLAGS
.
model_dir
=
'/tmp/bert20/'
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
FLAGS
.
mode
==
'export_only'
:
export_classifier
(
FLAGS
.
model_export_path
,
input_meta_data
,
bert_config
,
FLAGS
.
model_dir
)
return
strategy
=
distribution_utils
.
get_distribution_strategy
(
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
distribution_strategy
,
distribution_strategy
=
FLAGS
.
distribution_strategy
,
num_gpus
=
FLAGS
.
num_gpus
,
num_gpus
=
FLAGS
.
num_gpus
,
tpu_address
=
FLAGS
.
tpu
)
tpu_address
=
FLAGS
.
tpu
)
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
train_input_fn
=
get_dataset_fn
(
FLAGS
.
train_data_path
,
max_seq_length
,
FLAGS
.
train_batch_size
,
is_training
=
True
)
eval_input_fn
=
get_dataset_fn
(
eval_input_fn
=
get_dataset_fn
(
FLAGS
.
eval_data_path
,
FLAGS
.
eval_data_path
,
max_seq_length
,
input_meta_data
[
'
max_seq_length
'
]
,
FLAGS
.
eval_batch_size
,
FLAGS
.
eval_batch_size
,
is_training
=
False
)
is_training
=
False
)
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
FLAGS
.
mode
==
'predict'
:
run_bert
(
strategy
,
input_meta_data
,
bert_config
,
train_input_fn
,
with
strategy
.
scope
():
eval_input_fn
,
custom_callbacks
=
custom_callbacks
)
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
input_meta_data
[
'num_labels'
])[
0
]
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
classifier_model
)
latest_checkpoint_file
=
(
FLAGS
.
predict_checkpoint_path
or
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
))
assert
latest_checkpoint_file
logging
.
info
(
'Checkpoint file %s found and restoring from '
'checkpoint'
,
latest_checkpoint_file
)
checkpoint
.
restore
(
latest_checkpoint_file
).
assert_existing_objects_matched
()
preds
,
_
=
get_predictions_and_labels
(
strategy
,
classifier_model
,
eval_input_fn
,
return_probs
=
True
)
output_predict_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'test_results.tsv'
)
with
tf
.
io
.
gfile
.
GFile
(
output_predict_file
,
'w'
)
as
writer
:
logging
.
info
(
'***** Predict results *****'
)
for
probabilities
in
preds
:
output_line
=
'
\t
'
.
join
(
str
(
class_probability
)
for
class_probability
in
probabilities
)
+
'
\n
'
writer
.
write
(
output_line
)
return
if
FLAGS
.
mode
!=
'train_and_eval'
:
raise
ValueError
(
'Unsupported mode is specified: %s'
%
FLAGS
.
mode
)
train_input_fn
=
get_dataset_fn
(
FLAGS
.
train_data_path
,
input_meta_data
[
'max_seq_length'
],
FLAGS
.
train_batch_size
,
is_training
=
True
)
run_bert
(
strategy
,
input_meta_data
,
bert_config
,
train_input_fn
,
eval_input_fn
,
custom_callbacks
=
custom_callbacks
)
def
main
(
_
):
def
main
(
_
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment