Commit e93afea8 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Support to run prediction on question answering (SQuAD) task.

PiperOrigin-RevId: 324703765
parent 7ebdee5f
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
import collections import collections
import json import json
import os import os
from absl import logging from absl import logging
import dataclasses import dataclasses
import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
...@@ -84,6 +86,10 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -84,6 +86,10 @@ class QuestionAnsweringTask(base_task.Task):
self._tf_record_input_path, self._eval_examples, self._eval_features = ( self._tf_record_input_path, self._eval_examples, self._eval_features = (
self._preprocess_eval_data(params.validation_data)) self._preprocess_eval_data(params.validation_data))
def set_preprocessed_eval_input_path(self, eval_input_path):
"""Sets the path to the preprocessed eval data."""
self._tf_record_input_path = eval_input_path
def build_model(self): def build_model(self):
if self._hub_module: if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module) encoder_network = utils.get_encoder_from_hub(self._hub_module)
...@@ -242,10 +248,6 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -242,10 +248,6 @@ class QuestionAnsweringTask(base_task.Task):
step_outputs['end_logits']): step_outputs['end_logits']):
u_ids, s_logits, e_logits = ( u_ids, s_logits, e_logits = (
unique_ids.numpy(), start_logits.numpy(), end_logits.numpy()) unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
if u_ids.size == 1:
u_ids = [u_ids]
s_logits = [s_logits]
e_logits = [e_logits]
for values in zip(u_ids, s_logits, e_logits): for values in zip(u_ids, s_logits, e_logits):
state.append(self.raw_aggregated_result( state.append(self.raw_aggregated_result(
unique_id=values[0], unique_id=values[0],
...@@ -291,3 +293,46 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -291,3 +293,46 @@ class QuestionAnsweringTask(base_task.Task):
eval_metrics = {'exact_match': eval_metrics['exact_match'], eval_metrics = {'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']} 'final_f1': eval_metrics['final_f1']}
return eval_metrics return eval_metrics
def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
model: tf.keras.Model):
"""Predicts on the input data.
Args:
task: A `QuestionAnsweringTask` object.
params: A `cfg.DataConfig` object.
model: A keras.Model.
Returns:
A tuple of `all_predictions`, `all_nbest` and `scores_diff`, which
are dict and can be written to json files including prediction json file,
nbest json file and null_odds json file.
"""
tf_record_input_path, eval_examples, eval_features = (
task._preprocess_eval_data(params)) # pylint: disable=protected-access
# `tf_record_input_path` will overwrite `params.input_path`,
# when `task.buid_inputs()` is called.
task.set_preprocessed_eval_input_path(tf_record_input_path)
def predict_step(inputs):
"""Replicated prediction calculation."""
return task.validation_step(inputs, model)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
aggregated_outputs = utils.predict(predict_step, task.aggregate_logs, dataset)
all_predictions, all_nbest, scores_diff = (
task.squad_lib.postprocess_output(
eval_examples,
eval_features,
aggregated_outputs,
task.task_config.n_best_size,
task.task_config.max_answer_length,
task.task_config.validation_data.do_lower_case,
version_2_with_negative=(params.version_2_with_negative),
null_score_diff_threshold=task.task_config.null_score_diff_threshold,
verbose=False))
return all_predictions, all_nbest, scores_diff
...@@ -81,6 +81,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -81,6 +81,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
val_dataset = task.build_inputs(config.validation_data) val_dataset = task.build_inputs(config.validation_data)
val_iterator = iter(val_dataset) val_iterator = iter(val_dataset)
logs = task.validation_step(next(val_iterator), model, metrics=metrics) logs = task.validation_step(next(val_iterator), model, metrics=metrics)
# Mock that `logs` is from one replica.
logs = {x: (logs[x],) for x in logs}
logs = task.aggregate_logs(step_outputs=logs) logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs) metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics) self.assertIn("final_f1", metrics)
...@@ -160,6 +162,27 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -160,6 +162,27 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self._get_validation_data_config()) validation_data=self._get_validation_data_config())
self._run_task(config) self._run_task(config)
@parameterized.named_parameters(("squad1", False), ("squad2", True))
def test_predict(self, version_2_with_negative):
validation_data = self._get_validation_data_config(
version_2_with_negative=version_2_with_negative)
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=validation_data)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
all_predictions, all_nbest, scores_diff = question_answering.predict(
task, validation_data, model)
self.assertLen(all_predictions, 1)
self.assertLen(all_nbest, 1)
if version_2_with_negative:
self.assertLen(scores_diff, 1)
else:
self.assertEmpty(scores_diff)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -245,34 +245,25 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig, ...@@ -245,34 +245,25 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
""" """
is_regression = task.task_config.model.num_classes == 1 is_regression = task.task_config.model.num_classes == 1
@tf.function def predict_step(inputs):
def predict_step(iterator): """Replicated prediction calculation."""
"""Predicts on distributed devices.""" x, _ = inputs
outputs = task.inference_step(x, model)
def _replicated_step(inputs): if is_regression:
"""Replicated prediction calculation.""" return outputs
x, _ = inputs else:
outputs = task.inference_step(x, model) return tf.argmax(outputs, axis=-1)
if is_regression:
return outputs def aggregate_fn(state, outputs):
else:
return tf.argmax(outputs, axis=-1)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
"""Concatenates model's outputs.""" """Concatenates model's outputs."""
if state is None:
state = {'predictions': []}
for per_replica_batch_predictions in outputs: for per_replica_batch_predictions in outputs:
state.extend(per_replica_batch_predictions) state['predictions'].extend(per_replica_batch_predictions)
return state return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(), dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params) task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset. outputs = utils.predict(predict_step, aggregate_fn, dataset)
predictions = loop_fn( return outputs['predictions']
iter(dataset), num_steps=-1, state=[], reduce_fn=reduce_fn)
return predictions
...@@ -232,30 +232,25 @@ def predict(task: TaggingTask, params: cfg.DataConfig, ...@@ -232,30 +232,25 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
sentence id of the corresponding example. sentence id of the corresponding example.
""" """
@tf.function def predict_step(inputs):
def predict_step(iterator): """Replicated prediction calculation."""
"""Predicts on distributed devices.""" x, y = inputs
sentence_ids = x.pop('sentence_id')
def _replicated_step(inputs): outputs = task.inference_step(x, model)
"""Replicated prediction calculation.""" predict_ids = outputs['predict_ids']
x, y = inputs label_mask = tf.greater_equal(y, 0)
sentence_ids = x.pop('sentence_id') return dict(
outputs = task.inference_step(x, model) predict_ids=predict_ids,
predict_ids = outputs['predict_ids'] label_mask=label_mask,
label_mask = tf.greater_equal(y, 0) sentence_ids=sentence_ids)
return dict(
predict_ids=predict_ids, def aggregate_fn(state, outputs):
label_mask=label_mask,
sentence_ids=sentence_ids)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
"""Concatenates model's outputs.""" """Concatenates model's outputs."""
cur_predict_ids, cur_sentence_ids = state if state is None:
state = {'predict_ids': [], 'sentence_ids': []}
cur_predict_ids = state['predict_ids']
cur_sentence_ids = state['sentence_ids']
for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip( for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
outputs['predict_ids'], outputs['label_mask'], outputs['predict_ids'], outputs['label_mask'],
outputs['sentence_ids']): outputs['sentence_ids']):
...@@ -269,12 +264,9 @@ def predict(task: TaggingTask, params: cfg.DataConfig, ...@@ -269,12 +264,9 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
# Skip the padding label. # Skip the padding label.
if tmp_label_mask[i]: if tmp_label_mask[i]:
cur_predict_ids[-1].append(tmp_predict_ids[i]) cur_predict_ids[-1].append(tmp_predict_ids[i])
return cur_predict_ids, cur_sentence_ids return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(), dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params) task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset. outputs = utils.predict(predict_step, aggregate_fn, dataset)
predict_ids, sentence_ids = loop_fn( return outputs['predict_ids'], outputs['sentence_ids']
iter(dataset), num_steps=-1, state=([], []), reduce_fn=reduce_fn)
return predict_ids, sentence_ids
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common utils for tasks.""" """Common utils for tasks."""
from typing import Any, Callable
import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
...@@ -32,3 +35,34 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model: ...@@ -32,3 +35,34 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
return tf.keras.Model( return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids], inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output]) outputs=[sequence_output, pooled_output])
def predict(predict_step_fn: Callable[[Any], Any],
aggregate_fn: Callable[[Any, Any], Any],
dataset: tf.data.Dataset):
"""Runs prediction.
Args:
predict_step_fn: A callable such as `def predict_step(inputs)`, where
`inputs` are input tensors.
aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
`value` is the outputs from `predict_step_fn`.
dataset: A `tf.data.Dataset` object.
Returns:
The aggregated predictions.
"""
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
outputs = tf.distribute.get_strategy().run(
predict_step_fn, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
loop_fn = orbit.utils.create_loop_fn(predict_step)
# Set `num_steps` to -1 to exhaust the dataset.
outputs = loop_fn(
iter(dataset), num_steps=-1, state=None, reduce_fn=aggregate_fn) # pytype: disable=wrong-arg-types
return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment