Commit 6c2efff9 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 403471439
parent 3cc082ea
...@@ -34,7 +34,8 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer): ...@@ -34,7 +34,8 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
multi_task_model: Union[tf.keras.Model, multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel], base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer, optimizer: tf.optimizers.Optimizer,
trainer_options=None): trainer_options=None,
train_datasets=None):
self._strategy = tf.distribute.get_strategy() self._strategy = tf.distribute.get_strategy()
self._multi_task = multi_task self._multi_task = multi_task
self._multi_task_model = multi_task_model self._multi_task_model = multi_task_model
...@@ -55,6 +56,7 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer): ...@@ -55,6 +56,7 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
global_step=self.global_step, global_step=self.global_step,
**checkpoint_items) **checkpoint_items)
if train_datasets is None:
train_datasets = {} train_datasets = {}
for name, task in self.multi_task.tasks.items(): for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset( train_datasets[name] = orbit.utils.make_distributed_dataset(
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Multitask training driver library.""" """Multitask training driver library."""
# pytype: disable=attribute-error # pytype: disable=attribute-error
import os import os
from typing import List, Optional from typing import Any, List, Optional, Tuple
from absl import logging from absl import logging
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -36,11 +36,16 @@ TRAINERS = { ...@@ -36,11 +36,16 @@ TRAINERS = {
} }
def run_experiment(*, distribution_strategy: tf.distribute.Strategy, def run_experiment(
*,
distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask, task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel, mode: str, model: base_model.MultiTaskBaseModel,
mode: str,
params: configs.MultiTaskExperimentConfig, params: configs.MultiTaskExperimentConfig,
model_dir: str) -> base_model.MultiTaskBaseModel: model_dir: str,
trainer: base_trainer.MultiTaskBaseTrainer = None
) -> base_model.MultiTaskBaseModel:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -51,6 +56,8 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy, ...@@ -51,6 +56,8 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
or 'continuous_eval'. or 'continuous_eval'.
params: ExperimentConfig instance. params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries. model_dir: A 'str', a path to store model checkpoints and summaries.
trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`.
Returns: Returns:
model: `base_model.MultiTaskBaseModel` instance. model: `base_model.MultiTaskBaseModel` instance.
...@@ -66,6 +73,7 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy, ...@@ -66,6 +73,7 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler, sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
task.task_weights) task.task_weights)
kwargs.update(dict(task_sampler=sampler)) kwargs.update(dict(task_sampler=sampler))
if trainer is None:
trainer = TRAINERS[params.trainer.trainer_type]( trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None **kwargs) if is_training else None
if is_eval: if is_eval:
...@@ -145,7 +153,7 @@ def run_experiment_with_multitask_eval( ...@@ -145,7 +153,7 @@ def run_experiment_with_multitask_eval(
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True, save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model: trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment