Commit 992a864b authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388106758
parent e0ad9ff2
......@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
optimizer: tf.optimizers.Optimizer,
task_sampler: sampler.TaskSampler,
trainer_options=None):
super(MultiTaskInterleavingTrainer, self).__init__(
super().__init__(
multi_task=multi_task,
multi_task_model=multi_task_model,
optimizer=optimizer,
......@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1)
def train_loop_end(self):
"""Record loss and metric values per task."""
result = super().train_loop_end()
# Interleaving training does not have a good semantic for `total_loss`. In
# fact, it is always zero. To avoid confusion, we filter the `total_loss`
# from the result logs.
if 'total_loss' in result:
result.pop('total_loss')
return result
......@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertNotIn("total_loss", results)
@combinations.generate(all_strategy_combinations())
def test_trainer_with_configs(self, distribution):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment