Commit 992a864b authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388106758
parent e0ad9ff2
...@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer): ...@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
optimizer: tf.optimizers.Optimizer, optimizer: tf.optimizers.Optimizer,
task_sampler: sampler.TaskSampler, task_sampler: sampler.TaskSampler,
trainer_options=None): trainer_options=None):
super(MultiTaskInterleavingTrainer, self).__init__( super().__init__(
multi_task=multi_task, multi_task=multi_task,
multi_task_model=multi_task_model, multi_task_model=multi_task_model,
optimizer=optimizer, optimizer=optimizer,
...@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer): ...@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
self._task_train_step_map[name], args=(next(iterator_map[name]),)) self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1) self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1) self.task_step_counter(name).assign_add(1)
def train_loop_end(self):
"""Record loss and metric values per task."""
result = super().train_loop_end()
# Interleaving training does not have a good semantic for `total_loss`. In
# fact, it is always zero. To avoid confusion, we filter the `total_loss`
# from the result logs.
if 'total_loss' in result:
result.pop('total_loss')
return result
...@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
results["bar"].keys()) results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"], self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys()) results["foo"].keys())
self.assertNotIn("total_loss", results)
@combinations.generate(all_strategy_combinations()) @combinations.generate(all_strategy_combinations())
def test_trainer_with_configs(self, distribution): def test_trainer_with_configs(self, distribution):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment