Commit 29b4a322 authored by Hongkun Yu, committed by A. Unique TensorFlower

Refactor multitask evaluator: consume a list of tasks and optional dictionary of eval steps.

PiperOrigin-RevId: 386654855
parent d3b705d2
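
For orientation before the per-file diffs: the evaluator now consumes a flat list of base_task.Task objects plus an optional dictionary of eval steps keyed by task name, instead of a multitask.MultiTask wrapper. A minimal usage sketch inferred from the changes below (task_a, task_b, and model are placeholder objects, not part of this commit):

import tensorflow as tf

from official.modeling.multitask import evaluator as evaluator_lib

# Placeholders: task_a and task_b are base_task.Task instances with distinct
# `name` attributes; model is a tf.keras.Model or base_model.MultiTaskBaseModel.
test_evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=[task_a, task_b],  # previously: task=multitask.MultiTask(tasks=[...])
    model=model,
    # Optional; tasks without an entry fall back to the num_steps passed to evaluate().
    eval_steps={"task_a": 10, "task_b": 5})
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
# results is keyed by task name, e.g. results["task_a"]["validation_loss"].
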
@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
init_checkpoint: str = ""
model: base_config.Config = None
model: Optional[base_config.Config] = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
name: Optional[str] = None
@dataclasses.dataclass
......
@@ -23,6 +23,7 @@ from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
# TODO(hongkuny): deprecate the task_name once we migrated client code.
task_name: str = ""
task_config: cfg.TaskConfig = None
eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
Attributes:
eval_tasks: individual evaluation tasks.
"""
eval_tasks: MultiTaskConfig = MultiTaskConfig()
eval_tasks: Tuple[TaskRoutine, ...] = ()
@@ -16,14 +16,14 @@
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
def __init__(
self,
task: multitask.MultiTask,
eval_tasks: List[base_task.Task],
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
eval_steps: Optional[Dict[str, int]] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
eval_tasks: A list of tasks to evaluate.
model: tf.keras.Model instance.
global_step: the global step variable.
eval_steps: a dictionary of steps to run eval keyed by task names.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._task = task
self._tasks = eval_tasks
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model)
global_step=self.global_step, model=self.model)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
for name, task in self.task.tasks.items():
self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
self.eval_steps = eval_steps or {}
for task in self.tasks:
self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
name: get_function(name, task)
for name, task in self.task.tasks.items()
task.name: get_function(task.name, task) for task in self.tasks
}
@property
@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return self._strategy
@property
def task(self):
return self._task
def tasks(self):
return self._tasks
@property
def model(self):
@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for name in self.task.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean(
for task in self.tasks:
self._validation_losses[task.name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for name, task in self.task.tasks.items():
self._validation_metrics[name] = task.build_metrics(training=False)
for task in self.tasks:
self._validation_metrics[task.name] = task.build_metrics(training=False)
return self._validation_metrics
@property
@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items():
for task in self.tasks:
outputs = None
name = task.name
eval_iter = eval_iters[name]
task = self.task.tasks[name]
task_eval_steps = self.task.task_eval_steps(name) or num_steps
outputs = task_eval_loop(
task_eval_steps = self.eval_steps.get(name, None) or num_steps
outputs = self.task_fns[name](
eval_iter,
task_eval_steps,
state=outputs,
......
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask
def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self,
aggregated_logs,
global_step=None):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
eval_tasks=tasks, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
eval_tasks=tasks, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(results["bar"]["counter"],
5. * distribution.num_replicas_in_sync)
......
@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
else:
raise ValueError("The tasks argument has an invalid type: %s" %
type(tasks))
self._task_eval_steps = task_eval_steps or {}
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self.task_eval_steps = task_eval_steps or {}
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_eval_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
task_name = task_routine.task_name or task_routine.task_config.name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_routine.task_config, logging_dir=logging_dir, name=task_name)
task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight
return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def tasks(self):
return self._tasks
def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name]
def task_weight(self, task_name):
return self._task_weights[task_name]
......
@@ -15,7 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Optional
from typing import List, Optional
from absl import logging
import orbit
import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if is_eval:
eval_steps = task.task_eval_steps
evaluator = evaluator_lib.MultiTaskEvaluator(
task=task,
eval_tasks=task.tasks.values(),
model=model,
eval_steps=eval_steps,
global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
*,
distribution_strategy: tf.distribute.Strategy,
train_task: base_task.Task,
eval_tasks: multitask.MultiTask,
eval_tasks: List[base_task.Task],
mode: str,
params: configs.MultiEvalExperimentConfig,
model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
Args:
distribution_strategy: A distribution strategy.
train_task: A base_task.Task instance.
eval_tasks: A multitask.MultiTask with evaluation tasks.
eval_tasks: A list of evaluation tasks.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: MultiEvalExperimentConfig instance.
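
For reference, a hedged call sketch of the updated driver signature (the variable names are placeholders assumed to be built by the caller; the two-element return value follows the continuous-finetune caller further below):

from official.modeling.multitask import train_lib

_, eval_logs = train_lib.run_experiment_with_multitask_eval(
    distribution_strategy=strategy,  # a tf.distribute.Strategy
    train_task=train_task,           # a base_task.Task
    eval_tasks=eval_tasks,           # now a List[base_task.Task]
    mode="train_and_eval",           # or 'train', 'eval', 'continuous_eval'
    params=experiment_config,        # a configs.MultiEvalExperimentConfig
    model_dir=model_dir)
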
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
config=params,
task=train_task,
model=train_task.build_model(),
optimizer=train_task.create_optimizer(
params.trainer.optimizer_config, params.runtime),
optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
params.runtime),
train=True,
evaluate=False)
else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
model = trainer.model if trainer else train_task.build_model()
if is_eval:
eval_steps = dict([(task_routine.task_config.name,
task_routine.eval_steps)
for task_routine in params.eval_tasks])
evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks,
eval_tasks=eval_tasks,
model=model,
global_step=trainer.global_step if is_training else None,
eval_steps=eval_steps,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
else:
......
@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
task=configs.MultiTaskConfig(
task_routines=(
configs.TaskRoutine(
task_name='foo',
task_config=test_utils.FooConfig()),
task_name='foo', task_config=test_utils.FooConfig()),
configs.TaskRoutine(
task_name='bar', task_config=test_utils.BarConfig()))))
experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
model_dir = self.get_temp_dir()
experiment_config = configs.MultiEvalExperimentConfig(
task=test_utils.FooConfig(),
eval_tasks=configs.MultiTaskConfig(
task_routines=(
configs.TaskRoutine(
task_name='foo',
task_config=test_utils.FooConfig()),
configs.TaskRoutine(
task_name='bar', task_config=test_utils.BarConfig()))))
eval_tasks=(configs.TaskRoutine(
task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
configs.TaskRoutine(
task_name='bar',
task_config=test_utils.BarConfig(),
eval_steps=3)))
experiment_config = params_dict.override_params_dict(
experiment_config, self._test_config, is_strict=False)
with distribution_strategy.scope():
train_task = task_factory.get_task(experiment_config.task)
eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
eval_tasks = [
task_factory.get_task(config.task_config, name=config.task_name)
for config in experiment_config.eval_tasks
]
train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
train_task=train_task,
......
@@ -28,7 +28,6 @@ from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
@@ -167,7 +166,10 @@ def run_continuous_finetune(
with distribution_strategy.scope():
if isinstance(params, configs.MultiEvalExperimentConfig):
task = task_factory.get_task(params_replaced.task)
eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
eval_tasks = [
task_factory.get_task(config.task_config, name=config.task_name)
for config in params.eval_tasks
]
(_,
eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
......
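
Taken together, callers that previously built a multitask.MultiTask from MultiEvalExperimentConfig.eval_tasks now expand the tuple of TaskRoutine entries themselves, as the updated test and continuous-finetune code above do. A short sketch under that assumption (experiment_config is a placeholder configs.MultiEvalExperimentConfig):

from official.core import task_factory

# Build one base_task.Task per routine; each task is registered under the
# routine's task_name, which is also the key the evaluator uses for eval_steps.
eval_tasks = [
    task_factory.get_task(routine.task_config, name=routine.task_name)
    for routine in experiment_config.eval_tasks
]
eval_steps = {
    routine.task_name: routine.eval_steps
    for routine in experiment_config.eval_tasks
}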