Commit b3e4fefd authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Support numpy-based metrics through Orbit.

PiperOrigin-RevId: 317432167
parent b708fd68
......@@ -247,6 +247,14 @@ class Task(tf.Module):
"""Performs the forward step."""
return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
  """Optional aggregation over logs returned from a validation step.

  The base implementation is a deliberate no-op; subclasses that compute
  numpy-based metrics override this to fold `step_logs` into `state`.

  Args:
    state: accumulated aggregation state from previous calls (or None).
    step_logs: outputs returned by a single validation step.

  Returns:
    None — the default task performs no aggregation.
  """
  return None
def reduce_aggregated_logs(self, aggregated_logs):
  """Optional reduce of aggregated logs over validation steps.

  Args:
    aggregated_logs: the state produced by repeated `aggregate_logs` calls.

  Returns:
    A dict of final metric values; empty for the default task, which
    computes no numpy-based metrics.
  """
  del aggregated_logs  # Unused in the default implementation.
  return dict()
_REGISTERED_TASK_CLS = {}
......
......@@ -14,8 +14,11 @@
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
import logging
from absl import logging
import dataclasses
import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf
import tensorflow_hub as hub
......@@ -33,6 +36,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
# be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
metric_type: str = 'accuracy'
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0, # No masked language modeling head.
cls_heads=[
......@@ -59,6 +63,7 @@ class SentencePredictionTask(base_task.Task):
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
self.metric_type = params.metric_type
def build_model(self):
if self._hub_module:
......@@ -123,6 +128,57 @@ class SentencePredictionTask(base_task.Task):
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
  """Feeds the sentence-prediction head output into Keras compiled metrics.

  Args:
    compiled_metrics: a Keras compiled-metrics container.
    labels: ground-truth labels for the batch.
    model_outputs: dict of model outputs; only 'sentence_prediction' is used.
  """
  predictions = model_outputs['sentence_prediction']
  compiled_metrics.update_state(labels, predictions)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
  """Runs one validation step.

  For 'accuracy' the Keras compiled-metrics path of the base Task is used.
  For the numpy-based metrics ('matthews_corrcoef',
  'pearson_spearman_corr') the raw predictions and labels are returned so
  they can be accumulated by `aggregate_logs` and reduced on the host.

  Args:
    inputs: a (features, labels) tuple from the validation dataset.
    model: the Keras model to evaluate.
    metrics: optional compiled metrics, forwarded to the base class for
      the 'accuracy' path only.

  Returns:
    A dict of step outputs for the numpy-metric paths; whatever the base
    class returns for 'accuracy'.
  """
  if self.metric_type == 'accuracy':
    return super(SentencePredictionTask, self).validation_step(
        inputs, model, metrics)

  features, labels = inputs
  outputs = self.inference_step(features, model)
  loss = self.build_losses(
      labels=labels, model_outputs=outputs, aux_losses=model.losses)

  if self.metric_type == 'matthews_corrcoef':
    # Hard class decision per example; the leading axis keeps per-replica
    # outputs stackable.
    predictions = tf.expand_dims(
        tf.math.argmax(outputs['sentence_prediction'], axis=1), axis=0)
    return {
        self.loss: loss,
        'sentence_prediction': predictions,
        'labels': labels,
    }
  if self.metric_type == 'pearson_spearman_corr':
    # Regression-style metric: keep the raw prediction scores.
    return {
        self.loss: loss,
        'sentence_prediction': outputs['sentence_prediction'],
        'labels': labels,
    }
def aggregate_logs(self, state=None, step_outputs=None):
  """Accumulates per-step prediction and label tensors as numpy arrays.

  Args:
    state: accumulated state from previous calls, or None on the first
      call (fresh accumulators are created).
    step_outputs: dict with 'sentence_prediction' and 'labels', each a
      sequence of per-replica tensors exposing `.numpy()`.

  Returns:
    The updated state: a dict of lists of numpy arrays, one entry
    appended per call for each key.
  """
  if state is None:
    # First step: start with empty accumulators.
    state = {'sentence_prediction': [], 'labels': []}
  for key in ('sentence_prediction', 'labels'):
    merged = np.concatenate(
        [tensor.numpy() for tensor in step_outputs[key]], axis=0)
    state[key].append(merged)
  return state
def reduce_aggregated_logs(self, aggregated_logs):
  """Computes the final numpy-based metric from aggregated step outputs.

  Args:
    aggregated_logs: dict with 'sentence_prediction' and 'labels', each a
      list of numpy arrays as produced by `aggregate_logs`.

  Returns:
    A dict mapping `self.metric_type` to its value, or an empty dict when
    the metric type is not computed here (e.g. 'accuracy' is handled by
    Keras compiled metrics) — matching the base `Task` contract of
    returning a dict.
  """
  if self.metric_type == 'matthews_corrcoef':
    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    # matthews_corrcoef is symmetric in its two arguments, so the
    # (preds, labels) ordering is equivalent to (y_true, y_pred).
    return {
        self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
    }
  if self.metric_type == 'pearson_spearman_corr':
    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    # NOTE(review): assumes preds/labels reduce to 1-D score vectors here;
    # confirm shapes upstream if num_classes > 1.
    pearson_corr = stats.pearsonr(preds, labels)[0]
    spearman_corr = stats.spearmanr(preds, labels)[0]
    # GLUE convention: report the mean of the two correlations.
    return {self.metric_type: (pearson_corr + spearman_corr) / 2}
  # Fix: previously fell through and implicitly returned None for any
  # other metric type, violating the base-class contract (which returns
  # an empty dict).
  return {}
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
......
......@@ -16,6 +16,8 @@
"""Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.bert import configs
......@@ -25,20 +27,24 @@ from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction
class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for SentencePredictionTask."""
  # Fix: the pre-change class header (single-base `tf.test.TestCase`) was
  # left in by the diff paste, producing a duplicate class statement; only
  # the parameterized version is kept. The stale `_network_config`
  # assignment fragment is likewise dropped — network configs are now
  # built per-test via `get_network_config`.

  def setUp(self):
    super(SentencePredictionTaskTest, self).setUp()
    # Minimal dummy training-data config shared by all tests.
    self._train_data_config = bert.SentencePredictionDataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1)
def get_network_config(self, num_classes):
  """Builds a one-layer BERT pretrainer config with `num_classes` outputs.

  Fix: the diff paste left the pre-change `ClsHeadConfig` argument line
  (`num_classes=3`) interleaved with the new arguments, and the stale
  relocated `_train_data_config` lines after the return — both removed
  (the data config is assigned in `setUp`).

  Args:
    num_classes: number of classes for the sentence-prediction head.

  Returns:
    A `bert.BertPretrainerConfig` with no masked-LM head and a single
    classification head.
  """
  return bert.BertPretrainerConfig(
      encoder=encoders.TransformerEncoderConfig(
          vocab_size=30522, num_layers=1),
      num_masked_tokens=0,  # No masked language modeling head.
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=10,
              num_classes=num_classes,
              name="sentence_prediction")
      ])
def _run_task(self, config):
task = sentence_prediction.SentencePredictionTask(config)
......@@ -57,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
def test_task(self):
config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=self.get_temp_dir(),
network=self._network_config,
network=self.get_network_config(2),
train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
......@@ -84,12 +90,34 @@ class SentencePredictionTaskTest(tf.test.TestCase):
ckpt.save(config.init_checkpoint)
task.initialize(model)
@parameterized.parameters(("matthews_corrcoef", 2),
                          ("pearson_spearman_corr", 1))
def test_np_metrics(self, metric_type, num_classes):
  """Runs a validation step and checks the numpy metric is produced.

  Fix: the diff paste left the stale `def test_task_with_fit(self):`
  header and the stale `network=self._network_config,` argument embedded
  in this method — both removed; the new test body is kept intact.
  """
  config = sentence_prediction.SentencePredictionConfig(
      metric_type=metric_type,
      init_checkpoint=self.get_temp_dir(),
      network=self.get_network_config(num_classes),
      train_data=self._train_data_config)
  task = sentence_prediction.SentencePredictionTask(config)
  model = task.build_model()
  dataset = task.build_inputs(config.train_data)
  iterator = iter(dataset)
  strategy = tf.distribute.get_strategy()
  distributed_outputs = strategy.run(
      functools.partial(task.validation_step, model=model),
      args=(next(iterator),))
  outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                  distributed_outputs)
  # Aggregate twice to exercise both the fresh-state and accumulate paths.
  aggregated = task.aggregate_logs(step_outputs=outputs)
  aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
  self.assertIn(metric_type, task.reduce_aggregated_logs(aggregated))
def test_task_with_fit(self):
config = sentence_prediction.SentencePredictionConfig(
network=self.get_network_config(2), train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(lr=0.1),
......@@ -126,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub()
config = sentence_prediction.SentencePredictionConfig(
hub_module_url=hub_module_url,
network=self._network_config,
network=self.get_network_config(2),
train_data=self._train_data_config)
self._run_task(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment