Commit 6c2efff9 authored by Terry Huang, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 403471439
parent 3cc082ea
......@@ -34,7 +34,8 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
trainer_options=None):
trainer_options=None,
train_datasets=None):
self._strategy = tf.distribute.get_strategy()
self._multi_task = multi_task
self._multi_task_model = multi_task_model
......@@ -55,10 +56,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
global_step=self.global_step,
**checkpoint_items)
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
if train_datasets is None:
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
super().__init__(
train_dataset=train_datasets,
......
......@@ -15,7 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import List, Optional
from typing import Any, List, Optional, Tuple
from absl import logging
import orbit
import tensorflow as tf
......@@ -36,11 +36,16 @@ TRAINERS = {
}
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel, mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str) -> base_model.MultiTaskBaseModel:
def run_experiment(
*,
distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel,
mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str,
trainer: base_trainer.MultiTaskBaseTrainer = None
) -> base_model.MultiTaskBaseModel:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -51,6 +56,8 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`.
Returns:
model: `base_model.MultiTaskBaseModel` instance.
......@@ -66,8 +73,9 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
task.task_weights)
kwargs.update(dict(task_sampler=sampler))
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if trainer is None:
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if is_eval:
eval_steps = task.task_eval_steps
evaluator = evaluator_lib.MultiTaskEvaluator(
......@@ -145,7 +153,7 @@ def run_experiment_with_multitask_eval(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params.
Args:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment