Commit 379f951d authored by A. Unique TensorFlower

Fix issue with processing partial batches for the CoLA dataset.

PiperOrigin-RevId: 329012198
parent 45542535
@@ -159,7 +159,8 @@ class SentencePredictionTask(base_task.Task):
     if self.metric_type == 'matthews_corrcoef':
       logs.update({
           'sentence_prediction':
-              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=0),
+              # Ensure one prediction along batch dimension.
+              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
           'labels':
               labels,
       })
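Why the axis change matters: when validation runs over a dataset whose last batch is smaller than the global batch size, the per-step prediction tensors no longer share the same batch size. The sketch below is a minimal, hypothetical reproduction in plain NumPy (not the Model Garden aggregate_logs code), assuming per-step outputs are concatenated along their first dimension before the metric is computed:

# Hypothetical sketch of aggregating per-step predictions across batches
# of sizes 8, 8, and 5 (last batch partial). Not the Model Garden code.
import numpy as np

per_step_preds = [
    np.zeros(8, dtype=np.int64),
    np.zeros(8, dtype=np.int64),
    np.zeros(5, dtype=np.int64),  # partial final batch
]

# Old behavior: expand along axis=0 -> shapes (1, 8), (1, 8), (1, 5).
# Concatenating along the first dimension fails because the second
# dimension differs across steps.
try:
  np.concatenate([p[np.newaxis, :] for p in per_step_preds], axis=0)
except ValueError as e:
  print("axis=0 expansion breaks on the partial batch:", e)

# New behavior: expand along axis=1 -> shapes (8, 1), (8, 1), (5, 1).
# The batch dimension stays first, so partial batches concatenate cleanly.
merged = np.concatenate([p[:, np.newaxis] for p in per_step_preds], axis=0)
print(merged.shape)  # (21, 1)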
@@ -86,7 +86,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     iterator = iter(dataset)
     optimizer = tf.keras.optimizers.SGD(lr=0.1)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
-    task.validation_step(next(iterator), model, metrics=metrics)
+    return task.validation_step(next(iterator), model, metrics=metrics)

   @parameterized.named_parameters(
       ("init_cls_pooler", True),
@@ -182,6 +182,34 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
       aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
     self.assertIn(metric_type, task.reduce_aggregated_logs(aggregated))

+  def test_np_metrics_cola_partial_batch(self):
+    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
+    num_examples = 5
+    global_batch_size = 8
+    seq_length = 16
+    _create_fake_dataset(
+        train_data_path,
+        seq_length=seq_length,
+        num_classes=2,
+        num_examples=num_examples)
+    train_data_config = (
+        sentence_prediction_dataloader.SentencePredictionDataConfig(
+            input_path=train_data_path,
+            seq_length=seq_length,
+            is_training=True,
+            label_type="int",
+            global_batch_size=global_batch_size,
+            drop_remainder=False,
+            include_example_id=True))
+    config = sentence_prediction.SentencePredictionConfig(
+        metric_type="matthews_corrcoef",
+        model=self.get_model_config(2),
+        train_data=train_data_config)
+    outputs = self._run_task(config)
+    self.assertEqual(outputs["sentence_prediction"].shape.as_list(), [8, 1])
+
   def test_task_with_fit(self):
     config = sentence_prediction.SentencePredictionConfig(
         model=self.get_model_config(2), train_data=self._train_data_config)
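For context, the test name ("np_metrics") suggests the matthews_corrcoef value is computed outside the TF graph once all validation outputs have been aggregated. The snippet below is a hedged sketch of that final reduction, assuming the aggregated predictions arrive as an [num_examples, 1] array and using scikit-learn's matthews_corrcoef as a stand-in for whatever the task's reduce_aggregated_logs actually calls:

# Hypothetical reduction step, not the Model Garden implementation:
# flatten the aggregated [num_examples, 1] predictions and compute the
# Matthews correlation coefficient with scikit-learn.
import numpy as np
from sklearn.metrics import matthews_corrcoef

preds = np.array([[1], [0], [1], [1], [0]])   # aggregated predictions
labels = np.array([1, 0, 0, 1, 0])            # aggregated labels

score = matthews_corrcoef(labels, preds.reshape(-1))
print({"matthews_corrcoef": score})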