Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
238321c0
Commit
238321c0
authored
Mar 04, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 298817372
parent
75d13042
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
9 deletions
+61
-9
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+61
-9
No files found.
official/nlp/bert/run_classifier.py
View file @
238321c0
...
...
@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir,
return
bert_model
def
get_predictions_and_labels
(
strategy
,
trained_model
,
eval_input_fn
,
eval_steps
):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
order changes on distributing dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation steps.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@
tf
.
function
def
test_step
(
iterator
):
"""Computes predictions on distributed devices."""
def
_test_step_fn
(
inputs
):
"""Replicated predictions."""
inputs
,
labels
=
inputs
model_outputs
=
trained_model
(
inputs
,
training
=
False
)
return
model_outputs
,
labels
outputs
,
labels
=
strategy
.
experimental_run_v2
(
_test_step_fn
,
args
=
(
next
(
iterator
),))
# outputs: current batch logits as a tuple of shard logits
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
)
labels
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
labels
)
return
outputs
,
labels
def
_run_evaluation
(
test_iterator
):
"""Runs evaluation steps."""
preds
,
golds
=
list
(),
list
()
for
_
in
range
(
eval_steps
):
logits
,
labels
=
test_step
(
test_iterator
)
for
cur_logits
,
cur_labels
in
zip
(
logits
,
labels
):
preds
.
extend
(
tf
.
math
.
argmax
(
cur_logits
,
axis
=
1
).
numpy
())
golds
.
extend
(
cur_labels
.
numpy
().
tolist
())
return
preds
,
golds
test_iter
=
iter
(
strategy
.
experimental_distribute_datasets_from_function
(
eval_input_fn
))
predictions
,
labels
=
_run_evaluation
(
test_iter
)
return
predictions
,
labels
def
export_classifier
(
model_export_path
,
input_meta_data
,
restore_model_using_load_weights
,
bert_config
,
model_dir
):
restore_model_using_load_weights
,
bert_config
,
model_dir
):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API.
There are 2
different ways to save checkpoints. One is using
tf.train.Checkpoint and
another is using Keras model.save_weights().
Custom training loop
implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint
callback internally uses model.save_weights()
API. Since these two API's
cannot be used together, model loading logic
must be take into account how
model checkpoint was saved.
for custom checkpoint or to use model.load_weights() API.
There are 2
different ways to save checkpoints. One is using
tf.train.Checkpoint and
another is using Keras model.save_weights().
Custom training loop
implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint
callback internally uses model.save_weights()
API. Since these two API's
cannot be used together, model loading logic
must be take into account how
model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment