Commit b3e4fefd authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Support numpy-based metrics through Orbit.

PiperOrigin-RevId: 317432167
parent b708fd68
...@@ -247,6 +247,14 @@ class Task(tf.Module): ...@@ -247,6 +247,14 @@ class Task(tf.Module):
"""Performs the forward step.""" """Performs the forward step."""
return model(inputs, training=False) return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
    """Hook for accumulating validation-step outputs into `state`.

    The base implementation is a deliberate no-op; tasks that compute
    numpy-based metrics override this to fold each step's `step_logs`
    into the running `state`.
    """
    return None
def reduce_aggregated_logs(self, aggregated_logs):
    """Hook reducing logs accumulated over all validation steps to metrics.

    The base implementation reports no extra metrics; tasks override this
    to turn the aggregated state into a final metrics dict.
    """
    return dict()
_REGISTERED_TASK_CLS = {} _REGISTERED_TASK_CLS = {}
......
...@@ -14,8 +14,11 @@ ...@@ -14,8 +14,11 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Sentence prediction (classification) task.""" """Sentence prediction (classification) task."""
import logging from absl import logging
import dataclasses import dataclasses
import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
...@@ -33,6 +36,7 @@ class SentencePredictionConfig(cfg.TaskConfig): ...@@ -33,6 +36,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
# be specified. # be specified.
init_checkpoint: str = '' init_checkpoint: str = ''
hub_module_url: str = '' hub_module_url: str = ''
metric_type: str = 'accuracy'
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig( network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0, # No masked language modeling head. num_masked_tokens=0, # No masked language modeling head.
cls_heads=[ cls_heads=[
...@@ -59,6 +63,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -59,6 +63,7 @@ class SentencePredictionTask(base_task.Task):
self._hub_module = hub.load(params.hub_module_url) self._hub_module = hub.load(params.hub_module_url)
else: else:
self._hub_module = None self._hub_module = None
self.metric_type = params.metric_type
def build_model(self): def build_model(self):
if self._hub_module: if self._hub_module:
...@@ -123,6 +128,57 @@ class SentencePredictionTask(base_task.Task): ...@@ -123,6 +128,57 @@ class SentencePredictionTask(base_task.Task):
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs): def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels, model_outputs['sentence_prediction']) compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one validation step; output shape depends on the metric type.

    For 'accuracy' the Keras-metric base implementation is used. For the
    numpy-based metrics the raw predictions and labels are returned so they
    can be aggregated across steps and reduced on the host.
    """
    if self.metric_type == 'accuracy':
        # Keras compiled-metric path: defer entirely to the base class.
        return super(SentencePredictionTask,
                     self).validation_step(inputs, model, metrics)

    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)

    logs = {self.loss: loss, 'labels': labels}
    if self.metric_type == 'matthews_corrcoef':
        # Discrete class ids; the leading axis lets per-replica results be
        # concatenated along axis 0 during aggregation.
        logs['sentence_prediction'] = tf.expand_dims(
            tf.math.argmax(outputs['sentence_prediction'], axis=1), axis=0)
        return logs
    if self.metric_type == 'pearson_spearman_corr':
        # Regression scores are passed through unchanged.
        logs['sentence_prediction'] = outputs['sentence_prediction']
        return logs
def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates one step's predictions and labels into numpy arrays.

    Args:
      state: dict with 'sentence_prediction' and 'labels' lists of numpy
        arrays, or None on the first step (a fresh state is created).
      step_outputs: per-step outputs; each key maps to a list of
        per-replica tensors exposing `.numpy()`.

    Returns:
      The updated `state` dict.
    """
    if state is None:
        state = {'sentence_prediction': [], 'labels': []}
    for key in ('sentence_prediction', 'labels'):
        # Merge the per-replica results for this step into a single array.
        per_replica = [tensor.numpy() for tensor in step_outputs[key]]
        state[key].append(np.concatenate(per_replica, axis=0))
    return state
def reduce_aggregated_logs(self, aggregated_logs):
    """Computes the final numpy-based metric from the aggregated eval logs.

    Args:
      aggregated_logs: dict with 'sentence_prediction' and 'labels' lists of
        numpy arrays, as produced by `aggregate_logs`.

    Returns:
      A dict mapping `self.metric_type` to its scalar value, or an empty
      dict when the metric type is handled elsewhere (e.g. 'accuracy').
    """
    if self.metric_type not in ('matthews_corrcoef', 'pearson_spearman_corr'):
        # Keras-metric path ('accuracy'): nothing to reduce here; match the
        # base-class contract of returning a dict rather than None.
        return {}
    # Predictions arrive as [num_steps, batch] (argmax path) or
    # [num_examples, 1] (regression head); scipy's correlation routines and
    # sklearn's MCC expect 1-D arrays, so flatten both before comparing.
    preds = np.concatenate(
        aggregated_logs['sentence_prediction'], axis=0).reshape(-1)
    labels = np.concatenate(aggregated_logs['labels'], axis=0).reshape(-1)
    if self.metric_type == 'matthews_corrcoef':
        # sklearn convention is (y_true, y_pred); MCC itself is symmetric.
        return {
            self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
        }
    pearson_corr = stats.pearsonr(preds, labels)[0]
    spearman_corr = stats.spearmanr(preds, labels)[0]
    return {self.metric_type: (pearson_corr + spearman_corr) / 2}
def initialize(self, model): def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0.""" """Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint ckpt_dir_or_file = self.task_config.init_checkpoint
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
"""Tests for official.nlp.tasks.sentence_prediction.""" """Tests for official.nlp.tasks.sentence_prediction."""
import functools import functools
import os import os
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import configs from official.nlp.bert import configs
...@@ -25,20 +27,24 @@ from official.nlp.configs import encoders ...@@ -25,20 +27,24 @@ from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction from official.nlp.tasks import sentence_prediction
class SentencePredictionTaskTest(tf.test.TestCase): class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(SentencePredictionTaskTest, self).setUp() super(SentencePredictionTaskTest, self).setUp()
self._network_config = bert.BertPretrainerConfig( self._train_data_config = bert.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
def get_network_config(self, num_classes):
return bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig( encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1), vocab_size=30522, num_layers=1),
num_masked_tokens=0, num_masked_tokens=0,
cls_heads=[ cls_heads=[
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="sentence_prediction") inner_dim=10,
num_classes=num_classes,
name="sentence_prediction")
]) ])
self._train_data_config = bert.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
def _run_task(self, config): def _run_task(self, config):
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
...@@ -57,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase): ...@@ -57,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
def test_task(self): def test_task(self):
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=self.get_temp_dir(), init_checkpoint=self.get_temp_dir(),
network=self._network_config, network=self.get_network_config(2),
train_data=self._train_data_config) train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model() model = task.build_model()
...@@ -84,12 +90,34 @@ class SentencePredictionTaskTest(tf.test.TestCase): ...@@ -84,12 +90,34 @@ class SentencePredictionTaskTest(tf.test.TestCase):
ckpt.save(config.init_checkpoint) ckpt.save(config.init_checkpoint)
task.initialize(model) task.initialize(model)
def test_task_with_fit(self): @parameterized.parameters(("matthews_corrcoef", 2),
("pearson_spearman_corr", 1))
def test_np_metrics(self, metric_type, num_classes):
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
network=self._network_config, metric_type=metric_type,
init_checkpoint=self.get_temp_dir(),
network=self.get_network_config(num_classes),
train_data=self._train_data_config) train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model() model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
strategy = tf.distribute.get_strategy()
distributed_outputs = strategy.run(
functools.partial(task.validation_step, model=model),
args=(next(iterator),))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
distributed_outputs)
aggregated = task.aggregate_logs(step_outputs=outputs)
aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
self.assertIn(metric_type, task.reduce_aggregated_logs(aggregated))
def test_task_with_fit(self):
config = sentence_prediction.SentencePredictionConfig(
network=self.get_network_config(2), train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
model = task.compile_model( model = task.compile_model(
model, model,
optimizer=tf.keras.optimizers.SGD(lr=0.1), optimizer=tf.keras.optimizers.SGD(lr=0.1),
...@@ -126,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase): ...@@ -126,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub() hub_module_url = self._export_bert_tfhub()
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
hub_module_url=hub_module_url, hub_module_url=hub_module_url,
network=self._network_config, network=self.get_network_config(2),
train_data=self._train_data_config) train_data=self._train_data_config)
self._run_task(config) self._run_task(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment