Commit 29b4a322 authored by Hongkun Yu, committed by A. Unique TensorFlower

Refactor multitask evaluator: consume a list of tasks and an optional dictionary of eval steps.

PiperOrigin-RevId: 386654855
parent d3b705d2
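For client code migrating to the refactored evaluator, the change boils down to passing a flat list of `base_task.Task` objects (plus an optional per-task step dictionary) instead of a `multitask.MultiTask` wrapper. A minimal sketch of the new call, assuming `foo_task` and `bar_task` are already-built Task instances named "foo" and "bar" and `model` is the shared Keras model; these names are placeholders, not part of the commit:

import tensorflow as tf
from official.modeling.multitask import evaluator as evaluator_lib

# Before: the evaluator took a multitask.MultiTask wrapper.
#   evaluator_lib.MultiTaskEvaluator(task=multi_task, model=model)
# After: pass the tasks directly; eval_steps is optional and keyed by task name.
evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=[foo_task, bar_task],
    model=model,
    eval_steps={"foo": 100})
results = evaluator.evaluate(tf.convert_to_tensor(1000, dtype=tf.int32))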
@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
   init_checkpoint: str = ""
-  model: base_config.Config = None
+  model: Optional[base_config.Config] = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
+  name: Optional[str] = None


 @dataclasses.dataclass
...
@@ -23,6 +23,7 @@ from official.modeling import hyperparams
 @dataclasses.dataclass
 class TaskRoutine(hyperparams.Config):
+  # TODO(hongkuny): deprecate the task_name once we migrated client code.
   task_name: str = ""
   task_config: cfg.TaskConfig = None
   eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
   Attributes:
     eval_tasks: individual evaluation tasks.
   """
-  eval_tasks: MultiTaskConfig = MultiTaskConfig()
+  eval_tasks: Tuple[TaskRoutine, ...] = ()
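To illustrate the config change above: `MultiEvalExperimentConfig.eval_tasks` is now a flat tuple of `TaskRoutine` entries rather than a nested `MultiTaskConfig`, and each routine can carry its own `eval_steps`. A hedged sketch modeled on the updated tests further down; `FooConfig` and `BarConfig` stand in for real `TaskConfig` subclasses:

from official.modeling.multitask import configs

experiment_config = configs.MultiEvalExperimentConfig(
    task=FooConfig(),  # training task, unchanged by this refactor
    eval_tasks=(
        configs.TaskRoutine(task_name="foo", task_config=FooConfig(), eval_steps=2),
        configs.TaskRoutine(task_name="bar", task_config=BarConfig(), eval_steps=3)))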
@@ -16,14 +16,14 @@
 The evaluator implements the Orbit `AbstractEvaluator` interface.
 """
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union

 import gin
 import orbit
 import tensorflow as tf

+from official.core import base_task
 from official.core import train_utils
 from official.modeling.multitask import base_model
-from official.modeling.multitask import multitask


 @gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
   def __init__(
       self,
-      task: multitask.MultiTask,
+      eval_tasks: List[base_task.Task],
       model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
       global_step: Optional[tf.Variable] = None,
+      eval_steps: Optional[Dict[str, int]] = None,
       checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
     """Initialize common trainer for TensorFlow models.

     Args:
-      task: A multitask.MultiTask instance.
+      eval_tasks: A list of tasks to evaluate.
       model: tf.keras.Model instance.
       global_step: the global step variable.
+      eval_steps: a dictionary of steps to run eval keyed by task names.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._task = task
+    self._tasks = eval_tasks
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
     self._checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step,
-        model=self.model)
+        global_step=self.global_step, model=self.model)
     self._validation_losses = None
     self._validation_metrics = None

     # Builds per-task datasets.
     self.eval_datasets = {}
-    for name, task in self.task.tasks.items():
-      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
+    self.eval_steps = eval_steps or {}
+    for task in self.tasks:
+      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
           self.strategy, task.build_inputs, task.task_config.validation_data)

     # Builds per-task validation loops.
@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
       return orbit.utils.create_loop_fn(eval_step_fn)

     self.task_fns = {
-        name: get_function(name, task)
-        for name, task in self.task.tasks.items()
+        task.name: get_function(task.name, task) for task in self.tasks
     }

   @property
@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     return self._strategy

   @property
-  def task(self):
-    return self._task
+  def tasks(self):
+    return self._tasks

   @property
   def model(self):
@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_losses is None:
       # Builds the per-task metrics and losses.
       self._validation_losses = {}
-      for name in self.task.tasks:
-        self._validation_losses[name] = tf.keras.metrics.Mean(
+      for task in self.tasks:
+        self._validation_losses[task.name] = tf.keras.metrics.Mean(
             "validation_loss", dtype=tf.float32)
     return self._validation_losses
@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_metrics is None:
       # Builds the per-task metrics and losses.
       self._validation_metrics = {}
-      for name, task in self.task.tasks.items():
-        self._validation_metrics[name] = task.build_metrics(training=False)
+      for task in self.tasks:
+        self._validation_metrics[task.name] = task.build_metrics(training=False)
     return self._validation_metrics

   @property
@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     results = {}
     eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
-    for name, task_eval_loop in self.task_fns.items():
+    for task in self.tasks:
       outputs = None
+      name = task.name
       eval_iter = eval_iters[name]
-      task = self.task.tasks[name]
-      task_eval_steps = self.task.task_eval_steps(name) or num_steps
-      outputs = task_eval_loop(
+      task_eval_steps = self.eval_steps.get(name, None) or num_steps
+      outputs = self.task_fns[name](
           eval_iter,
           task_eval_steps,
          state=outputs,
...
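The loop above resolves the per-task step budget as `self.eval_steps.get(name, None) or num_steps`: an explicit entry in `eval_steps` wins, and every other task falls back to the `num_steps` argument of `evaluate()`. A short sketch of that behavior, reusing the placeholder names from the earlier example:

# Only "foo" has an explicit budget; "bar" uses the evaluate() argument.
evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=[foo_task, bar_task], model=model, eval_steps={"foo": 10})
evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
# -> "foo" evaluates for 10 steps, "bar" for 5.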
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
 from official.core import base_task
 from official.core import config_definitions as cfg
 from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask


 def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
             np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
     return state

-  def reduce_aggregated_logs(self,
-                             aggregated_logs,
-                             global_step=None):
+  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     for k, v in aggregated_logs.items():
       aggregated_logs[k] = np.sum(np.stack(v, axis=0))
     return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
     self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(results["bar"]["counter"],
                      5. * distribution.num_replicas_in_sync)
...
@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     else:
       raise ValueError("The tasks argument has an invalid type: %s" %
                        type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
     self._task_weights = task_weights or {}
     self._task_weights = dict([
         (name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     task_eval_steps = {}
     task_weights = {}
     for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
       tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
       task_eval_steps[task_name] = task_routine.eval_steps
       task_weights[task_name] = task_routine.task_weight
     return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   def tasks(self):
     return self._tasks

-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
-
   def task_weight(self, task_name):
     return self._task_weights[task_name]
...
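For existing callers of `MultiTask`, `task_eval_steps` changes from a lookup method to a plain dictionary attribute, which `run_experiment` in the next file forwards directly to the evaluator. A hedged before/after sketch, assuming `multi_task` is an already-constructed `MultiTask`:

# Before: per-task eval steps were read through a method.
#   steps = multi_task.task_eval_steps("foo")
# After: task_eval_steps is an ordinary dict keyed by task name.
steps = multi_task.task_eval_steps.get("foo")  # None when not configured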
@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Optional
+from typing import List, Optional

 from absl import logging
 import orbit
 import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
   trainer = TRAINERS[params.trainer.trainer_type](
       **kwargs) if is_training else None
   if is_eval:
+    eval_steps = task.task_eval_steps
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=task,
+        eval_tasks=task.tasks.values(),
         model=model,
+        eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
     *,
     distribution_strategy: tf.distribute.Strategy,
     train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
     mode: str,
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
   Args:
     distribution_strategy: A distribution distribution_strategy.
     train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
     mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
       or 'continuous_eval'.
     params: MultiEvalExperimentConfig instance.
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
-        optimizer=train_task.create_optimizer(
-            params.trainer.optimizer_config, params.runtime),
+        optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
+                                              params.runtime),
         train=True,
         evaluate=False)
   else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
   model = trainer.model if trainer else train_task.build_model()

   if is_eval:
+    eval_steps = dict([(task_routine.task_config.name,
+                        task_routine.eval_steps)
+                       for task_routine in params.eval_tasks])
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=eval_tasks,
+        eval_tasks=eval_tasks,
         model=model,
         global_step=trainer.global_step if is_training else None,
+        eval_steps=eval_steps,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
   else:
...
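Since `run_experiment_with_multitask_eval` now takes a plain list of evaluation tasks, callers build that list (and, if needed, the matching `eval_steps` dict) from the `TaskRoutine` entries themselves, as the updated test and continuous-finetune code below do. A sketch of that wiring, assuming `params` is a `MultiEvalExperimentConfig`:

from official.core import task_factory

eval_tasks = [
    task_factory.get_task(routine.task_config, name=routine.task_name)
    for routine in params.eval_tasks
]
# Equivalent to the dict([...]) built in the driver above, written as a comprehension.
eval_steps = {routine.task_config.name: routine.eval_steps
              for routine in params.eval_tasks}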
@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
         task=configs.MultiTaskConfig(
             task_routines=(
                 configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
+                    task_name='foo', task_config=test_utils.FooConfig()),
                 configs.TaskRoutine(
                     task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiEvalExperimentConfig(
         task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar', task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
     with distribution_strategy.scope():
       train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
       train_lib.run_experiment_with_multitask_eval(
           distribution_strategy=distribution_strategy,
           train_task=train_task,
...
@@ -28,7 +28,6 @@ from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
 from official.modeling.multitask import configs
-from official.modeling.multitask import multitask
 from official.modeling.multitask import train_lib as multitask_train_lib
@@ -167,7 +166,10 @@ def run_continuous_finetune(
     with distribution_strategy.scope():
       if isinstance(params, configs.MultiEvalExperimentConfig):
         task = task_factory.get_task(params_replaced.task)
-        eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
+        eval_tasks = [
+            task_factory.get_task(config.task_config, name=config.task_name)
+            for config in params.eval_tasks
+        ]
         (_,
          eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
              distribution_strategy=distribution_strategy,
...