".github/vscode:/vscode.git/clone" did not exist on "1c4e4a34b702234ad4f9a41d8ac0c853cacde4d5"
Commit ca919737 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 321100219
parent b0693846
@@ -23,6 +23,9 @@ from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.data import data_loader_factory
 
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
 @dataclasses.dataclass
 class SentencePredictionDataConfig(cfg.DataConfig):
   """Data config for sentence prediction task (tasks/sentence_prediction)."""
@@ -30,6 +33,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   global_batch_size: int = 32
   is_training: bool = True
   seq_length: int = 128
+  label_type: str = 'int'
 
 
 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -42,11 +46,12 @@ class SentencePredictionDataLoader:
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
...
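As an illustrative aside (not part of the diff above), a minimal sketch of how the new label_type field might be set for float-labeled regression data; the module path and input_path below are assumptions, not taken from this commit:

from official.nlp.data import sentence_prediction_dataloader as loader_lib

# Hypothetical config: 'float' maps to tf.float32 via LABEL_TYPES_MAP, so the
# 'label_ids' feature of each tf.Example would be parsed as a float scalar.
data_config = loader_lib.SentencePredictionDataConfig(
    input_path='/tmp/train.tf_record',  # placeholder path
    global_batch_size=32,
    seq_length=128,
    label_type='float')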
@@ -31,6 +31,10 @@ from official.nlp.modeling import models
 from official.nlp.tasks import utils
 
+METRIC_TYPES = frozenset(
+    ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
+
+
 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A classifier/regressor configuration."""
@@ -68,6 +72,9 @@ class SentencePredictionTask(base_task.Task):
       self._hub_module = hub.load(params.hub_module_url)
     else:
       self._hub_module = None
+    if params.metric_type not in METRIC_TYPES:
+      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
+
     self.metric_type = params.metric_type
 
   def build_model(self):
@@ -77,7 +84,7 @@ class SentencePredictionTask(base_task.Task):
       encoder_network = encoders.instantiate_encoder_from_cfg(
           self.task_config.model.encoder)
-    # Currently, we only supports bert-style sentence prediction finetuning.
+    # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
@@ -86,8 +93,11 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)
 
   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    loss = tf.keras.losses.sparse_categorical_crossentropy(
-        labels, tf.cast(model_outputs, tf.float32), from_logits=True)
+    if self.task_config.model.num_classes == 1:
+      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
+    else:
+      loss = tf.keras.losses.sparse_categorical_crossentropy(
+          labels, tf.cast(model_outputs, tf.float32), from_logits=True)
     if aux_losses:
       loss += tf.add_n(aux_losses)
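For intuition, a toy sketch (values made up, not from this commit) of the two branches build_losses now selects between: mean squared error for single-output regression and sparse categorical crossentropy for classification.

import tensorflow as tf

# Regression branch (num_classes == 1): MSE on raw scores.
reg_labels = tf.constant([[2.5], [4.0]])
reg_preds = tf.constant([[2.0], [3.5]])
mse_loss = tf.keras.losses.mean_squared_error(reg_labels, reg_preds)

# Classification branch: sparse categorical crossentropy on logits.
cls_labels = tf.constant([1, 0])
cls_logits = tf.constant([[0.2, 1.3], [2.0, -0.5]])
xent_loss = tf.keras.losses.sparse_categorical_crossentropy(
    cls_labels, tf.cast(cls_logits, tf.float32), from_logits=True)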
@@ -103,8 +113,12 @@ class SentencePredictionTask(base_task.Task):
           input_word_ids=dummy_ids,
           input_mask=dummy_ids,
           input_type_ids=dummy_ids)
-      y = tf.zeros((1, 1), dtype=tf.int32)
-      return (x, y)
+      if self.task_config.model.num_classes == 1:
+        y = tf.zeros((1,), dtype=tf.float32)
+      else:
+        y = tf.zeros((1, 1), dtype=tf.int32)
+      return x, y
 
     dataset = tf.data.Dataset.range(1)
     dataset = dataset.repeat()
@@ -116,7 +130,11 @@ class SentencePredictionTask(base_task.Task):
 
   def build_metrics(self, training=None):
     del training
-    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
+    if self.task_config.model.num_classes == 1:
+      metrics = [tf.keras.metrics.MeanSquaredError()]
+    else:
+      metrics = [
+          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
     return metrics
 
   def process_metrics(self, metrics, labels, model_outputs):
@@ -154,6 +172,7 @@ class SentencePredictionTask(base_task.Task):
       return None
     if state is None:
       state = {'sentence_prediction': [], 'labels': []}
+    # TODO(b/160712818): Add support for concatenating partial batches.
     state['sentence_prediction'].append(
         np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                        axis=0))
@@ -162,15 +181,21 @@ class SentencePredictionTask(base_task.Task):
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs):
-    if self.metric_type == 'matthews_corrcoef':
+    if self.metric_type == 'accuracy':
+      return None
+    elif self.metric_type == 'matthews_corrcoef':
       preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      preds = np.reshape(preds, -1)
       labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      labels = np.reshape(labels, -1)
       return {
           self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
       }
-    if self.metric_type == 'pearson_spearman_corr':
+    elif self.metric_type == 'pearson_spearman_corr':
       preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      preds = np.reshape(preds, -1)
       labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      labels = np.reshape(labels, -1)
       pearson_corr = stats.pearsonr(preds, labels)[0]
       spearman_corr = stats.spearmanr(preds, labels)[0]
       corr_metric = (pearson_corr + spearman_corr) / 2
...
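For intuition, a toy sketch (values made up) of the two reductions reduce_aggregated_logs now distinguishes: Matthews correlation for classification and the average of Pearson and Spearman correlations for regression, both computed on flattened arrays.

import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics

# Regression-style predictions and labels, flattened as in the task code.
preds = np.array([[0.1], [2.4], [3.9], [1.2]]).reshape(-1)
labels = np.array([[0.0], [2.5], [4.0], [1.0]]).reshape(-1)
pearson_corr = stats.pearsonr(preds, labels)[0]
spearman_corr = stats.spearmanr(preds, labels)[0]
corr_metric = (pearson_corr + spearman_corr) / 2  # averaged correlation

# Classification-style predictions and labels for Matthews correlation.
cls_preds = np.array([1, 0, 1, 1])
cls_labels = np.array([1, 0, 0, 1])
mcc = sklearn_metrics.matthews_corrcoef(cls_labels, cls_preds)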
@@ -85,6 +85,42 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     ckpt.save(config.init_checkpoint)
     task.initialize(model)
 
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "regression",
+          "num_classes": 1,
+      },
+      {
+          "testcase_name": "classification",
+          "num_classes": 2,
+      },
+  )
+  def test_metrics_and_losses(self, num_classes):
+    config = sentence_prediction.SentencePredictionConfig(
+        init_checkpoint=self.get_temp_dir(),
+        model=self.get_model_config(num_classes),
+        train_data=self._train_data_config)
+    task = sentence_prediction.SentencePredictionTask(config)
+    model = task.build_model()
+    metrics = task.build_metrics()
+    if num_classes == 1:
+      self.assertIsInstance(metrics[0], tf.keras.metrics.MeanSquaredError)
+    else:
+      self.assertIsInstance(
+          metrics[0], tf.keras.metrics.SparseCategoricalAccuracy)
+
+    dataset = task.build_inputs(config.train_data)
+    iterator = iter(dataset)
+    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    task.train_step(next(iterator), model, optimizer, metrics=metrics)
+
+    logs = task.validation_step(next(iterator), model, metrics=metrics)
+    loss = logs["loss"].numpy()
+    if num_classes == 1:
+      self.assertAlmostEqual(loss, 42.77483, places=3)
+    else:
+      self.assertAlmostEqual(loss, 3.57627e-6, places=3)
+
   @parameterized.parameters(("matthews_corrcoef", 2),
                             ("pearson_spearman_corr", 1))
   def test_np_metrics(self, metric_type, num_classes):
...
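Finally, a hypothetical sketch of how the regression pieces added in this commit might be wired together; class and field names come from the diff and test above, while module paths and any defaults not shown are assumptions.

from official.nlp.tasks import sentence_prediction

# Single-output model trained with MSE and evaluated with the averaged
# Pearson/Spearman correlation; encoder and data settings are left at their
# (assumed) defaults for brevity.
config = sentence_prediction.SentencePredictionConfig(
    metric_type='pearson_spearman_corr',
    model=sentence_prediction.ModelConfig(num_classes=1))
task = sentence_prediction.SentencePredictionTask(config)
metrics = task.build_metrics()  # [tf.keras.metrics.MeanSquaredError()]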