Commit f31e7ef9 authored by Chen Chen, committed by A. Unique TensorFlower

Add a predict method for sentence_prediction task.

PiperOrigin-RevId: 321112219
parent 79f3edb6
@@ -14,9 +14,12 @@
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
from typing import List, Union

from absl import logging
import dataclasses
import numpy as np
import orbit
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf
@@ -223,3 +226,52 @@ class SentencePredictionTask(base_task.Task):
      status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

def predict(task: SentencePredictionTask, params: cfg.DataConfig,
            model: tf.keras.Model) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A `tf.keras.Model` object.

  Returns:
    A list of predictions with length `num_examples`. For a regression task,
    each element in the list is the predicted score; for a classification
    task, each element is the predicted class id.
  """
  is_regression = task.task_config.model.num_classes == 1

  @tf.function
  def predict_step(iterator):
    """Runs one prediction step on distributed devices."""

    def _replicated_step(inputs):
      """Computes per-replica predictions."""
      x, _ = inputs
      outputs = task.inference_step(x, model)
      if is_regression:
        return outputs
      else:
        return tf.argmax(outputs, axis=-1)

    outputs = tf.distribute.get_strategy().run(
        _replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(
        tf.distribute.get_strategy().experimental_local_results, outputs)

  def reduce_fn(state, outputs):
    """Concatenates the model's per-replica outputs into `state`."""
    for per_replica_batch_predictions in outputs:
      state.extend(per_replica_batch_predictions)
    return state

  loop_fn = orbit.utils.create_loop_fn(predict_step)
  dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
                                                 task.build_inputs, params)
  # Set `num_steps` to -1 to iterate until the dataset is exhausted.
  predictions = loop_fn(
      iter(dataset), num_steps=-1, state=[], reduce_fn=reduce_fn)
  return predictions
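
Note (not part of this commit): a minimal sketch of how the new helper might be
called under a distribution strategy. `task_config` and `predict_data_config`
are hypothetical placeholders for configs built elsewhere with
`SentencePredictionConfig` and `SentencePredictionDataConfig`.

# Usage sketch; placeholder configs assumed, see note above.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  task = SentencePredictionTask(task_config)
  model = task.build_model()
  # `predict` picks up the strategy via tf.distribute.get_strategy(),
  # so it is called inside the strategy scope.
  predictions = predict(task, predict_data_config, model)
# With num_classes == 1 each prediction is a float score;
# otherwise it is an integer class id.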
@@ -18,6 +18,7 @@ import functools
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.bert import configs
@@ -28,6 +29,35 @@ from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import sentence_prediction

def _create_fake_dataset(output_path, seq_length, num_classes, num_examples):
  """Creates a fake dataset of serialized `tf.train.Example` records."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  def create_float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

  for _ in range(num_examples):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length,))
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(np.ones_like(input_ids))
    features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
    if num_classes == 1:
      features["label_ids"] = create_float_feature([np.random.random()])
    else:
      features["label_ids"] = create_int_feature(
          [np.random.randint(0, num_classes)])
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()
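
As a side note (not part of this commit), the records this helper writes can
be parsed back with a matching feature spec; a minimal sketch assuming the
classification case (num_classes > 1) and the test's seq_length of 16:

# Sketch: parse the fake records back; shapes/dtypes assume the test settings.
feature_spec = {
    "input_ids": tf.io.FixedLenFeature([16], tf.int64),
    "input_mask": tf.io.FixedLenFeature([16], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([16], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
}
dataset = tf.data.TFRecordDataset("test.tf_record").map(
    lambda record: tf.io.parse_single_example(record, feature_spec))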

class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
@@ -189,6 +219,35 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
        train_data=self._train_data_config)
    self._run_task(config)

  @parameterized.named_parameters(("classification", 5), ("regression", 1))
  def test_prediction(self, num_classes):
    task_config = sentence_prediction.SentencePredictionConfig(
        model=self.get_model_config(num_classes=num_classes),
        train_data=self._train_data_config)
    task = sentence_prediction.SentencePredictionTask(task_config)
    model = task.build_model()

    test_data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
    seq_length = 16
    num_examples = 100
    _create_fake_dataset(
        test_data_path,
        seq_length=seq_length,
        num_classes=num_classes,
        num_examples=num_examples)

    test_data_config = (
        sentence_prediction_dataloader.SentencePredictionDataConfig(
            input_path=test_data_path,
            seq_length=seq_length,
            is_training=False,
            label_type="int" if num_classes > 1 else "float",
            global_batch_size=16,
            drop_remainder=False))

    predictions = sentence_prediction.predict(task, test_data_config, model)
    self.assertLen(predictions, num_examples)

if __name__ == "__main__":
  tf.test.main()