Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
238321c0
Commit
238321c0
authored
Mar 04, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 298817372
parent
75d13042
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
9 deletions
+61
-9
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+61
-9
No files found.
official/nlp/bert/run_classifier.py
View file @
238321c0
...
@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir,
...
@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir,
return
bert_model
return
bert_model
def
get_predictions_and_labels
(
strategy
,
trained_model
,
eval_input_fn
,
eval_steps
):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
order changes on distributing dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation steps.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@
tf
.
function
def
test_step
(
iterator
):
"""Computes predictions on distributed devices."""
def
_test_step_fn
(
inputs
):
"""Replicated predictions."""
inputs
,
labels
=
inputs
model_outputs
=
trained_model
(
inputs
,
training
=
False
)
return
model_outputs
,
labels
outputs
,
labels
=
strategy
.
experimental_run_v2
(
_test_step_fn
,
args
=
(
next
(
iterator
),))
# outputs: current batch logits as a tuple of shard logits
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
)
labels
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
labels
)
return
outputs
,
labels
def
_run_evaluation
(
test_iterator
):
"""Runs evaluation steps."""
preds
,
golds
=
list
(),
list
()
for
_
in
range
(
eval_steps
):
logits
,
labels
=
test_step
(
test_iterator
)
for
cur_logits
,
cur_labels
in
zip
(
logits
,
labels
):
preds
.
extend
(
tf
.
math
.
argmax
(
cur_logits
,
axis
=
1
).
numpy
())
golds
.
extend
(
cur_labels
.
numpy
().
tolist
())
return
preds
,
golds
test_iter
=
iter
(
strategy
.
experimental_distribute_datasets_from_function
(
eval_input_fn
))
predictions
,
labels
=
_run_evaluation
(
test_iter
)
return
predictions
,
labels
def
export_classifier
(
model_export_path
,
input_meta_data
,
def
export_classifier
(
model_export_path
,
input_meta_data
,
restore_model_using_load_weights
,
restore_model_using_load_weights
,
bert_config
,
model_dir
):
bert_config
,
model_dir
):
"""Exports a trained model as a `SavedModel` for inference.
"""Exports a trained model as a `SavedModel` for inference.
Args:
Args:
model_export_path: a string specifying the path to the SavedModel directory.
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API.
for custom checkpoint or to use model.load_weights() API.
There are 2
There are 2
different ways to save checkpoints. One is using
different ways to save checkpoints. One is using
tf.train.Checkpoint and
tf.train.Checkpoint and
another is using Keras model.save_weights().
another is using Keras model.save_weights().
Custom training loop
Custom training loop
implementation uses tf.train.Checkpoint API
implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint
and Keras ModelCheckpoint
callback internally uses model.save_weights()
callback internally uses model.save_weights()
API. Since these two API's
API. Since these two API's
cannot be used together, model loading logic
cannot be used together, model loading logic
must be take into account how
must be take into account how
model checkpoint was saved.
model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
summaries are stored.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment