"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "92f9df5dd7f1d7cbcd089406e297a59a1c29b4f9"
Commit 4728eb68 authored by Chen Chen, committed by A. Unique TensorFlower

Support running prediction on the question answering (SQuAD) task.

PiperOrigin-RevId: 324703765
parent 8f34962f
@@ -17,8 +17,10 @@
import collections
import json
import os
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
import tensorflow_hub as hub
@@ -84,6 +86,10 @@ class QuestionAnsweringTask(base_task.Task):
self._tf_record_input_path, self._eval_examples, self._eval_features = (
self._preprocess_eval_data(params.validation_data))
def set_preprocessed_eval_input_path(self, eval_input_path):
"""Sets the path to the preprocessed eval data."""
self._tf_record_input_path = eval_input_path
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
@@ -242,10 +248,6 @@ class QuestionAnsweringTask(base_task.Task):
step_outputs['end_logits']):
u_ids, s_logits, e_logits = (
unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
if u_ids.size == 1:
u_ids = [u_ids]
s_logits = [s_logits]
e_logits = [e_logits]
for values in zip(u_ids, s_logits, e_logits):
state.append(self.raw_aggregated_result(
unique_id=values[0],
@@ -291,3 +293,46 @@ class QuestionAnsweringTask(base_task.Task):
eval_metrics = {'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']}
return eval_metrics
def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
model: tf.keras.Model):
"""Predicts on the input data.
Args:
task: A `QuestionAnsweringTask` object.
params: A `cfg.DataConfig` object.
model: A keras.Model.
Returns:
A tuple of `all_predictions`, `all_nbest` and `scores_diff`. Each is a
dict that can be written to a JSON file: the prediction file, the nbest
file and the null_odds file, respectively.
"""
tf_record_input_path, eval_examples, eval_features = (
task._preprocess_eval_data(params)) # pylint: disable=protected-access
# `tf_record_input_path` will overwrite `params.input_path`,
# when `task.build_inputs()` is called.
task.set_preprocessed_eval_input_path(tf_record_input_path)
def predict_step(inputs):
"""Replicated prediction calculation."""
return task.validation_step(inputs, model)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
aggregated_outputs = utils.predict(predict_step, task.aggregate_logs, dataset)
all_predictions, all_nbest, scores_diff = (
task.squad_lib.postprocess_output(
eval_examples,
eval_features,
aggregated_outputs,
task.task_config.n_best_size,
task.task_config.max_answer_length,
task.task_config.validation_data.do_lower_case,
version_2_with_negative=(params.version_2_with_negative),
null_score_diff_threshold=task.task_config.null_score_diff_threshold,
verbose=False))
return all_predictions, all_nbest, scores_diff
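For orientation, a minimal driver sketch for this new entry point. The checkpoint path and output file name are placeholders, and `config` stands for a `QuestionAnsweringConfig` built as in the test below; none of this is part of the commit itself:

task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
model.load_weights('/path/to/checkpoint')  # placeholder checkpoint
all_predictions, all_nbest, scores_diff = question_answering.predict(
    task, config.validation_data, model)
# Each returned value is a dict, ready to be serialized to JSON.
with tf.io.gfile.GFile('predictions.json', 'w') as writer:
  writer.write(json.dumps(all_predictions, indent=4) + '\n')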
@@ -81,6 +81,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
val_dataset = task.build_inputs(config.validation_data)
val_iterator = iter(val_dataset)
logs = task.validation_step(next(val_iterator), model, metrics=metrics)
# Wrap each output in a tuple to mimic results from a single replica.
logs = {x: (logs[x],) for x in logs}
logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics)
@@ -160,6 +162,27 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self._get_validation_data_config())
self._run_task(config)
@parameterized.named_parameters(("squad1", False), ("squad2", True))
def test_predict(self, version_2_with_negative):
validation_data = self._get_validation_data_config(
version_2_with_negative=version_2_with_negative)
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=validation_data)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
all_predictions, all_nbest, scores_diff = question_answering.predict(
task, validation_data, model)
self.assertLen(all_predictions, 1)
self.assertLen(all_nbest, 1)
if version_2_with_negative:
self.assertLen(scores_diff, 1)
else:
self.assertEmpty(scores_diff)
if __name__ == "__main__":
tf.test.main()
@@ -245,34 +245,25 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
"""
is_regression = task.task_config.model.num_classes == 1
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
outputs = task.inference_step(x, model)
if is_regression:
return outputs
else:
return tf.argmax(outputs, axis=-1)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
def predict_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
outputs = task.inference_step(x, model)
if is_regression:
return outputs
else:
return tf.argmax(outputs, axis=-1)
def aggregate_fn(state, outputs):
"""Concatenates model's outputs."""
if state is None:
state = {'predictions': []}
for per_replica_batch_predictions in outputs:
state.extend(per_replica_batch_predictions)
state['predictions'].extend(per_replica_batch_predictions)
return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset.
predictions = loop_fn(
iter(dataset), num_steps=-1, state=[], reduce_fn=reduce_fn)
return predictions
outputs = utils.predict(predict_step, aggregate_fn, dataset)
return outputs['predictions']
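Note the changed aggregation contract: `utils.predict` starts its loop with `state=None`, so `aggregate_fn` creates its own dict on the first call instead of relying on a caller-supplied initial state. A toy illustration of that contract, with invented per-replica outputs:

state = None
for outputs in ([tf.constant([1, 0])], [tf.constant([1])]):  # fake replica results
  state = aggregate_fn(state, outputs)
# state['predictions'] now holds the concatenated per-example predictions.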
@@ -232,30 +232,25 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
sentence id of the corresponding example.
"""
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, y = inputs
sentence_ids = x.pop('sentence_id')
outputs = task.inference_step(x, model)
predict_ids = outputs['predict_ids']
label_mask = tf.greater_equal(y, 0)
return dict(
predict_ids=predict_ids,
label_mask=label_mask,
sentence_ids=sentence_ids)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
def predict_step(inputs):
"""Replicated prediction calculation."""
x, y = inputs
sentence_ids = x.pop('sentence_id')
outputs = task.inference_step(x, model)
predict_ids = outputs['predict_ids']
label_mask = tf.greater_equal(y, 0)
return dict(
predict_ids=predict_ids,
label_mask=label_mask,
sentence_ids=sentence_ids)
def aggregate_fn(state, outputs):
"""Concatenates model's outputs."""
cur_predict_ids, cur_sentence_ids = state
if state is None:
state = {'predict_ids': [], 'sentence_ids': []}
cur_predict_ids = state['predict_ids']
cur_sentence_ids = state['sentence_ids']
for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
outputs['predict_ids'], outputs['label_mask'],
outputs['sentence_ids']):
@@ -269,12 +264,9 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
# Skip the padding label.
if tmp_label_mask[i]:
cur_predict_ids[-1].append(tmp_predict_ids[i])
return cur_predict_ids, cur_sentence_ids
return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset.
predict_ids, sentence_ids = loop_fn(
iter(dataset), num_steps=-1, state=([], []), reduce_fn=reduce_fn)
return predict_ids, sentence_ids
outputs = utils.predict(predict_step, aggregate_fn, dataset)
return outputs['predict_ids'], outputs['sentence_ids']
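The `label_mask` above assumes the tagging data pipeline marks padded positions with a negative label id, so `tf.greater_equal(y, 0)` filters them out. A tiny illustration with toy tensors:

y = tf.constant([[3, 7, -1, -1]])      # -1 marks padding
label_mask = tf.greater_equal(y, 0)    # [[True, True, False, False]]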
@@ -14,6 +14,9 @@
# limitations under the License.
# ==============================================================================
"""Common utils for tasks."""
from typing import Any, Callable
import orbit
import tensorflow as tf
import tensorflow_hub as hub
@@ -32,3 +35,34 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
def predict(predict_step_fn: Callable[[Any], Any],
aggregate_fn: Callable[[Any, Any], Any],
dataset: tf.data.Dataset):
"""Runs prediction.
Args:
predict_step_fn: A callable such as `def predict_step(inputs)`, where
`inputs` are input tensors.
aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
`value` is the outputs from `predict_step_fn`.
dataset: A `tf.data.Dataset` object.
Returns:
The aggregated predictions.
"""
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
outputs = tf.distribute.get_strategy().run(
predict_step_fn, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
loop_fn = orbit.utils.create_loop_fn(predict_step)
# Set `num_steps` to -1 to exhaust the dataset.
outputs = loop_fn(
iter(dataset), num_steps=-1, state=None, reduce_fn=aggregate_fn) # pytype: disable=wrong-arg-types
return outputs
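A self-contained sketch of exercising this helper on its own, under the default single-device strategy; every name below is invented for the illustration:

def _toy_build_inputs(params, input_context=None):
  """Builds a toy dataset; `params` is unused in this sketch."""
  del params, input_context
  return tf.data.Dataset.from_tensor_slices(tf.random.uniform([8, 4])).batch(2)

def _toy_predict_step(inputs):
  return tf.argmax(inputs, axis=-1)  # stand-in for a real inference step

def _toy_aggregate_fn(state, outputs):
  state = state if state is not None else []
  for per_replica_batch in outputs:  # one tensor per local replica
    state.extend(per_replica_batch.numpy().tolist())
  return state

dataset = orbit.utils.make_distributed_dataset(
    tf.distribute.get_strategy(), _toy_build_inputs, None)
print(predict(_toy_predict_step, _toy_aggregate_fn, dataset))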