Commit 786346f3 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 347480509
parent a167bf93
@@ -301,7 +301,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     def step_fn(inputs):
       logs = self.task.validation_step(
           inputs, model=self.model, metrics=self.validation_metrics)
-      self._validation_loss.update_state(logs[self.task.loss])
+      if self.task.loss in logs:
+        self._validation_loss.update_state(logs[self.task.loss])
       return logs

     distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
@@ -311,8 +312,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def eval_end(self, aggregated_logs=None):
     """Processes evaluation results."""
     logs = {}
-    for metric in self.validation_metrics + [self.validation_loss]:
+    for metric in self.validation_metrics:
       logs[metric.name] = metric.result()
+    if self.validation_loss.count.numpy() != 0:
+      logs[self.validation_loss.name] = self.validation_loss.result()
+    else:
+      # `self.validation_loss` metric was not updated, because the validation
+      # loss was not returned from the task's `validation_step` method.
+      logging.info("The task did not report validation loss.")
     if aggregated_logs:
       metrics = self.task.reduce_aggregated_logs(aggregated_logs)
       logs.update(metrics)
...
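
The `eval_end` change above relies on a property of Keras `Mean` metrics: a metric that was never updated has a zero count, so its result is not a real loss value and should not be reported. A minimal standalone sketch of that guard, assuming `validation_loss` is a `tf.keras.metrics.Mean` as in the trainer (the free-standing names below are illustrative, not part of this commit):

import tensorflow as tf
from absl import logging

# A Mean metric that was never updated has count == 0; in that case no
# 'validation_loss' entry is added to the logs.
validation_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)

logs = {}
if validation_loss.count.numpy() != 0:
  logs[validation_loss.name] = validation_loss.result()
else:
  logging.info('The task did not report validation loss.')
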
@@ -54,8 +54,8 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
                 }
             })))

-  def create_test_trainer(self, config, model_dir=None):
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+  def create_test_trainer(self, config, model_dir=None, task=None):
+    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
     ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
     trainer = trainer_lib.Trainer(
         config,
@@ -79,6 +79,25 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
       trainer = self.create_test_trainer(self._config)
     logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+    self.assertIn('validation_loss', logs)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_trainer_validate_without_loss(self, distribution):
+
+    class MockTaskWithoutValidationLoss(mock_task.MockTask):
+
+      def validation_step(self, inputs, model, metrics=None):
+        # Disable validation loss.
+        logs = super().validation_step(inputs, model)
+        del logs[self.loss]
+        return logs
+
+    with distribution.scope():
+      task = MockTaskWithoutValidationLoss()
+      trainer = self.create_test_trainer(self._config, task=task)
+    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+    self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+    self.assertNotIn('validation_loss', logs)

   @combinations.generate(
       combinations.combine(
...
@@ -212,7 +212,10 @@ class QuestionAnsweringTask(base_task.Task):
                                         input_context)

   def build_metrics(self, training=None):
-    del training
+    if not training:
+      # We cannot compute start/end_position_accuracy because start/end_position
+      # labels are not available in the validation dataset (b/173794928).
+      return []
     # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
     metrics = [
         tf.keras.metrics.SparseCategoricalAccuracy(
@@ -244,8 +247,9 @@ class QuestionAnsweringTask(base_task.Task):
     unique_ids = features.pop('unique_ids')
     model_outputs = self.inference_step(features, model)
     start_logits, end_logits = model_outputs
+    # We cannot compute validation_loss here, because start/end_position
+    # labels are not available in the validation dataset (b/173794928).
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
         'start_logits': start_logits,
         'end_logits': end_logits,
@@ -293,8 +297,6 @@ class QuestionAnsweringTask(base_task.Task):
     if self.task_config.validation_data.version_2_with_negative:
       eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
                                                   scores_diff)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['final_exact'],
           'exact_match_threshold': eval_metrics['final_exact_thresh'],
@@ -305,8 +307,6 @@ class QuestionAnsweringTask(base_task.Task):
       }
     else:
       eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['exact_match'],
           'final_f1': eval_metrics['final_f1']
@@ -417,7 +417,6 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
     class_logits = model_outputs['class_logits']
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
         'start_top_predictions': start_top_predictions,
         'end_top_predictions': end_top_predictions,
...
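
For other tasks that, like QuestionAnsweringTask, cannot compute a loss on their validation data, the opt-out contract introduced above is simply to leave the `self.loss` key out of the logs returned by `validation_step`; the trainer then skips the metric update and omits 'validation_loss' from the evaluation logs. A hedged sketch of such a task, assuming the TF Model Garden `official.core.base_task` module (the class name and the 'outputs' key are hypothetical, not part of this commit):

from official.core import base_task  # assumed TF Model Garden import path


class NoEvalLossTask(base_task.Task):
  """Hypothetical task that reports no validation loss.

  Required methods such as build_model/build_inputs are omitted for brevity.
  """

  def build_metrics(self, training=None):
    if not training:
      # No label-dependent metrics can be computed at evaluation time.
      return []
    return super().build_metrics(training=training)

  def validation_step(self, inputs, model, metrics=None):
    outputs = self.inference_step(inputs, model)
    # No `self.loss` key in the returned logs, so Trainer.eval_end will not
    # report a 'validation_loss' entry for this task.
    return {'outputs': outputs}
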
@@ -250,6 +250,7 @@ class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     logs = task.aggregate_logs(step_outputs=logs)
     metrics = task.reduce_aggregated_logs(logs)
     self.assertIn("final_f1", metrics)
+    self.assertNotIn("loss", metrics)

   def test_task(self):
     config = question_answering.XLNetQuestionAnsweringConfig(
...