Commit 786346f3 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 347480509
parent a167bf93
@@ -301,6 +301,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     def step_fn(inputs):
       logs = self.task.validation_step(
           inputs, model=self.model, metrics=self.validation_metrics)
-      self._validation_loss.update_state(logs[self.task.loss])
+      if self.task.loss in logs:
+        self._validation_loss.update_state(logs[self.task.loss])
       return logs
@@ -311,8 +312,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def eval_end(self, aggregated_logs=None):
     """Processes evaluation results."""
     logs = {}
-    for metric in self.validation_metrics + [self.validation_loss]:
+    for metric in self.validation_metrics:
       logs[metric.name] = metric.result()
+    if self.validation_loss.count.numpy() != 0:
+      logs[self.validation_loss.name] = self.validation_loss.result()
+    else:
+      # `self.validation_loss` metric was not updated, because the validation
+      # loss was not returned from the task's `validation_step` method.
+      logging.info("The task did not report validation loss.")
     if aggregated_logs:
       metrics = self.task.reduce_aggregated_logs(aggregated_logs)
       logs.update(metrics)
......
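The new `count` guard in `eval_end` works because `validation_loss` is a streaming Keras metric: its `count` weight stays at zero until `update_state` is called at least once. Below is a minimal standalone sketch of that pattern, assuming a plain `tf.keras.metrics.Mean`; the helper name and the `'loss'` key are illustrative, not part of this change.

import tensorflow as tf

# Stand-in for the Trainer's `validation_loss` metric (assumed to be a Mean).
validation_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)

def eval_end_loss_entry(step_logs):
  """Illustrative helper: mirrors how eval_end now treats the loss entry."""
  if 'loss' in step_logs:  # the Trainer checks `self.task.loss` instead
    validation_loss.update_state(step_logs['loss'])
  logs = {}
  if validation_loss.count.numpy() != 0:
    logs[validation_loss.name] = validation_loss.result()
  # Otherwise the key is simply absent, rather than a misleading 0.0.
  return logs

print(eval_end_loss_entry({}))             # {}  -> no 'validation_loss' key
print(eval_end_loss_entry({'loss': 0.7}))  # {'validation_loss': <tf.Tensor 0.7>}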
@@ -54,8 +54,8 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
                 }
             })))

-  def create_test_trainer(self, config, model_dir=None):
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+  def create_test_trainer(self, config, model_dir=None, task=None):
+    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
     ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
     trainer = trainer_lib.Trainer(
         config,
@@ -79,6 +79,25 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
       trainer = self.create_test_trainer(self._config)
       logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
       self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+      self.assertIn('validation_loss', logs)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_trainer_validate_without_loss(self, distribution):
+
+    class MockTaskWithoutValidationLoss(mock_task.MockTask):
+
+      def validation_step(self, inputs, model, metrics=None):
+        # Disable validation loss.
+        logs = super().validation_step(inputs, model)
+        del logs[self.loss]
+        return logs
+
+    with distribution.scope():
+      task = MockTaskWithoutValidationLoss()
+      trainer = self.create_test_trainer(self._config, task=task)
+      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+      self.assertNotIn('validation_loss', logs)

   @combinations.generate(
       combinations.combine(
......
@@ -212,7 +212,10 @@ class QuestionAnsweringTask(base_task.Task):
         input_context)

   def build_metrics(self, training=None):
-    del training
+    if not training:
+      # We cannot compute start/end_position_accuracy because start/end_position
+      # labels are not available in the validation dataset (b/173794928).
+      return []
     # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
     metrics = [
         tf.keras.metrics.SparseCategoricalAccuracy(
@@ -244,8 +247,9 @@ class QuestionAnsweringTask(base_task.Task):
     unique_ids = features.pop('unique_ids')
     model_outputs = self.inference_step(features, model)
     start_logits, end_logits = model_outputs
+    # We cannot compute validation_loss here, because start/end_position
+    # labels are not available in the validation dataset (b/173794928).
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
         'start_logits': start_logits,
         'end_logits': end_logits,
@@ -293,8 +297,6 @@ class QuestionAnsweringTask(base_task.Task):
     if self.task_config.validation_data.version_2_with_negative:
       eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
                                                   scores_diff)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['final_exact'],
           'exact_match_threshold': eval_metrics['final_exact_thresh'],
@@ -305,8 +307,6 @@ class QuestionAnsweringTask(base_task.Task):
       }
     else:
       eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['exact_match'],
           'final_f1': eval_metrics['final_f1']
@@ -417,7 +417,6 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
     class_logits = model_outputs['class_logits']
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
         'start_top_predictions': start_top_predictions,
         'end_top_predictions': end_top_predictions,
......
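With the placeholder `self.loss: 0.0` entries removed from both QA tasks' `validation_step` logs, evaluation output for these tasks no longer contains a `validation_loss` key, so consumers of the logs should treat it as optional. A hedged sketch of such a consumer follows; the helper name and formatting are illustrative and not part of this commit.

def format_eval_logs(logs):
  """Illustrative: render evaluation metrics whether or not a loss was reported."""
  lines = [f'{name}: {float(logs[name]):.4f}' for name in sorted(logs)]
  if 'validation_loss' not in logs:
    # QuestionAnsweringTask no longer reports a validation loss (b/173794928).
    lines.append('validation_loss: not reported by the task')
  return '\n'.join(lines)

# Keys as produced by the SQuAD v1.1 branch of reduce_aggregated_logs above.
print(format_eval_logs({'exact_match': 81.2, 'final_f1': 88.6}))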
@@ -250,6 +250,7 @@ class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     logs = task.aggregate_logs(step_outputs=logs)
     metrics = task.reduce_aggregated_logs(logs)
     self.assertIn("final_f1", metrics)
+    self.assertNotIn("loss", metrics)

   def test_task(self):
     config = question_answering.XLNetQuestionAnsweringConfig(
......