Commit 3b0d58e2 authored by Hongkun Yu, committed by A. Unique TensorFlower

Add multitask evaluation

PiperOrigin-RevId: 350705651
parent 72284a6c
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstraction of multi-task model."""
from typing import Text, Dict
import tensorflow as tf
class MultiTaskBaseModel(tf.Module):
"""Base class that holds multi-task model computation."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._sub_tasks = self._instantiate_sub_tasks()
def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
"""Abstract function that sets up the computation for each sub-task.
Returns:
A map from task name (as string) to a tf.keras.Model object that
represents the sub-task in the multi-task pool.
"""
raise NotImplementedError(
"_instantiate_sub_task_models() is not implemented.")
@property
def sub_tasks(self):
"""Fetch a map of task name (string) to task model (tf.keras.Model)."""
return self._sub_tasks
def initialize(self):
"""Optional function that loads a pre-train checkpoint."""
return
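An illustrative usage sketch follows; the subclass, its two sub-task names, and the toy Keras models are hypothetical, since the base class only requires that `_instantiate_sub_tasks` return a name-to-model dict.

import tensorflow as tf
from official.modeling.multitask import base_model

class ToyMultiTaskModel(base_model.MultiTaskBaseModel):
  """Hypothetical model holding one independent Keras model per sub-task."""

  def _instantiate_sub_tasks(self):
    # Keys are task names; values carry the per-task computation.
    return {
        "classification": tf.keras.Sequential(
            [tf.keras.layers.Dense(10, activation="softmax")]),
        "regression": tf.keras.Sequential([tf.keras.layers.Dense(1)]),
    }

toy_model = ToyMultiTaskModel()
assert set(toy_model.sub_tasks) == {"classification", "regression"}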
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class TaskRoutine(base_config.Config):
task_name: str = ""
task_config: cfg.TaskConfig = None
mixing_steps: int = 1
eval_steps: Optional[int] = None
task_weight: Optional[float] = None
@dataclasses.dataclass
class MultiTaskConfig(base_config.Config):
init_checkpoint: str = ""
model: base_config.Config = None
task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class MultiEvalExperimentConfig(base_config.Config):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
task: the single-stream training task.
eval_tasks: individual evaluation tasks.
trainer: the trainer configuration.
runtime: the runtime configuration.
"""
task: cfg.TaskConfig = cfg.TaskConfig()
eval_tasks: MultiTaskConfig = MultiTaskConfig()
trainer: cfg.TrainerConfig = cfg.TrainerConfig()
runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
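As an illustrative sketch, the config dataclasses above might compose as follows; the routine names, step counts, and weights are hypothetical placeholders.

from official.core import config_definitions as cfg
from official.modeling.multitask import configs

experiment = configs.MultiEvalExperimentConfig(
    task=cfg.TaskConfig(),
    eval_tasks=configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="task_a", task_config=cfg.TaskConfig(),
                eval_steps=100, task_weight=1.0),
            configs.TaskRoutine(
                task_name="task_b", task_config=cfg.TaskConfig(),
                eval_steps=50, task_weight=0.5),
        )),
    trainer=cfg.TrainerConfig(),
    runtime=cfg.RuntimeConfig())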
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(self,
task: multitask.MultiTask,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
model: a tf.keras.Model or base_model.MultiTaskBaseModel instance.
global_step: the global step variable.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._task = task
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
# TODO(hongkuny): Define a more robust way to handle the training/eval
# checkpoint loading.
if hasattr(self.model, "checkpoint_items"):
# Each evaluation task can have different models and load a subset of
# components from the training checkpoint. This is assuming the
# checkpoint items are able to load the weights of the evaluation model.
checkpoint_items = self.model.checkpoint_items
else:
# This is assuming the evaluation model is exactly the training model.
checkpoint_items = dict(model=self.model)
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
**checkpoint_items)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
for name, task in self.task.tasks.items():
self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
def get_function(task_name, task):
task_metrics = self.validation_metrics[task_name]
task_loss = self.validation_losses[task_name]
if isinstance(self.model, base_model.MultiTaskBaseModel):
model = self.model.sub_tasks[task_name]
else:
model = self.model
def step_fn(inputs):
logs = task.validation_step(inputs, model=model, metrics=task_metrics)
task_loss.update_state(logs[task.loss])
return logs
@tf.function
def eval_step_fn(iterator):
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
name: get_function(name, task)
for name, task in self.task.tasks.items()
}
@property
def strategy(self):
return self._strategy
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def global_step(self):
return self._global_step
@property
def validation_losses(self):
"""Accesses the validation loss metric object."""
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for name in self.task.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for name, task in self.task.tasks.items():
self._validation_metrics[name] = task.build_metrics(training=False)
return self._validation_metrics
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
def evaluate(self, num_steps: tf.Tensor):
"""Performs evaluation for each `EvalTask`."""
for metric in self.validation_losses.values():
metric.reset_states()
for metrics in self.validation_metrics.values():
for metric in metrics:
metric.reset_states()
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items():
outputs = None
eval_iter = eval_iters[name]
task = self.task.tasks[name]
task_eval_steps = self.task.task_eval_steps(name) or num_steps
outputs = task_eval_loop(
eval_iter,
task_eval_steps,
state=outputs,
reduce_fn=task.aggregate_logs)
task_metrics = self.validation_metrics[name]
task_loss = self.validation_losses[name]
logs = {}
for metric in task_metrics + [task_loss]:
logs[metric.name] = metric.result()
if outputs:
metrics = task.reduce_aggregated_logs(outputs)
logs.update(metrics)
results[name] = logs
return results
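A brief usage sketch, mirroring the test below; `my_tasks` and `my_model` are hypothetical stand-ins for real task and model objects.

task_pool = multitask.MultiTask(tasks=my_tasks)  # A list of base_task.Task objects.
eval_driver = MultiTaskEvaluator(task=task_pool, model=my_model)
results = eval_driver.evaluate(tf.convert_to_tensor(100, dtype=tf.int32))
# `results` maps each task name to that task's metric results, e.g.
# results["task_a"]["validation_loss"].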
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class MockModel(tf.keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dense = tf.keras.layers.Dense(1)
def call(self, inputs):
if "y" in inputs:
self.add_loss(tf.zeros((1,), dtype=tf.float32))
else:
self.add_loss(tf.ones((1,), dtype=tf.float32))
return self.dense(inputs["x"])
class MockTask(base_task.Task):
"""Mock task object for testing."""
def build_metrics(self, training: bool = True):
del training
return [tf.keras.metrics.Accuracy(name="acc")]
def build_inputs(self, params):
def generate_data(_):
x = tf.zeros(shape=(2,), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
if self.name == "bar":
return dict(x=x, y=x), label
else:
return dict(x=x), label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
logs = super().validation_step(inputs, model, metrics)
logs["counter"] = tf.ones((1,), dtype=tf.float32)
return logs
def aggregate_logs(self, state, step_outputs):
if state is None:
state = {}
for key, value in step_outputs.items():
if key not in state:
state[key] = []
state[key].append(
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
self.assertEqual(results["bar"]["validation_loss"], 0.0)
self.assertEqual(results["foo"]["validation_loss"], 1.0)
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator_numpy_metrics(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(results["bar"]["counter"],
5. * distribution.num_replicas_in_sync)
self.assertEqual(results["foo"]["counter"],
5. * distribution.num_replicas_in_sync)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling.multitask import configs
TrainerConfig = config_definitions.TrainerConfig
RuntimeConfig = config_definitions.RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
"""A multi-task class to manage multiple tasks."""
def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_mixing_steps: Optional[Dict[str, int]] = None,
task_weights: Optional[Dict[str, float]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None):
"""MultiTask initialization.
Args:
tasks: a list or a flat dict of Task.
task_mixing_steps: a dict of (task, mixing steps).
task_weights: a dict of (task, loss weight).
task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object.
"""
super().__init__(name=name)
if isinstance(tasks, list):
self._tasks = {}
for task in tasks:
if task.name in self._tasks:
raise ValueError("Duplicated tasks found, task.name is %s" %
task.name)
self._tasks[task.name] = task
elif isinstance(tasks, dict):
self._tasks = tasks
else:
raise ValueError("The tasks argument has an invalid type: %s" %
type(tasks))
self._task_eval_steps = task_eval_steps or {}
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self._task_mixing_steps = task_mixing_steps or {}
self._task_mixing_steps = dict([
(name, self._task_mixing_steps.get(name, 1)) for name in self.tasks
])
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, None)) for name in self.tasks
])
@classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
tasks = {}
task_eval_steps = {}
task_mixing_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_eval_steps[task_name] = task_routine.eval_steps
task_mixing_steps[task_name] = task_routine.mixing_steps
task_weights[task_name] = task_routine.task_weight
return cls(
tasks,
task_mixing_steps=task_mixing_steps,
task_eval_steps=task_eval_steps,
task_weights=task_weights)
@property
def tasks(self):
return self._tasks
def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name]
def task_mixing_steps(self, task_name):
return self._task_mixing_steps[task_name]
def task_weight(self, task_name):
return self._task_weights[task_name]
@classmethod
def create_optimizer(cls, trainer_config: TrainerConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
trainer_config: the parameters of the trainer.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
# Configures the optimizer when loss_scale is set in the runtime config.
# This helps avoid overflow/underflow for float16 computations.
if runtime_config and runtime_config.loss_scale:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == "float16",
loss_scale=runtime_config.loss_scale)
return optimizer
def joint_train_step(self, task_inputs, multi_task_model, optimizer,
task_metrics):
"""The joint train step.
Args:
task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskModel instance.
optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics.
Returns:
A dictionary of losses, including per-task losses and their weighted sum.
"""
losses = {}
with tf.GradientTape() as tape:
total_loss = 0.0
for name, model in multi_task_model.sub_tasks.items():
inputs = task_inputs[name]
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
elif isinstance(inputs, dict):
features, labels = inputs, inputs
else:
raise ValueError("The iterator output is neither a tuple nor a "
"dictionary. It is not implemented to support "
"such outputs.")
outputs = model(features, training=True)
task_loss = self.tasks[name].build_losses(labels, outputs)
task_weight = self.task_weight(name)
total_loss += task_weight * task_loss
losses[name] = task_loss
self.tasks[name].process_metrics(task_metrics[name], labels, outputs)
# Scales the loss, since the default gradient allreduce performs a sum
# inside the optimizer.
scaled_loss = total_loss / tf.distribute.get_strategy(
).num_replicas_in_sync
tvars = multi_task_model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
losses["total_loss"] = total_loss
return losses
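A small sketch of building a task pool directly from task objects; `TaskA` and `TaskB` stand in for hypothetical base_task.Task subclasses registered elsewhere.

pool = MultiTask(
    tasks=[TaskA(params=config_definitions.TaskConfig(), name="task_a"),
           TaskB(params=config_definitions.TaskConfig(), name="task_b")],
    task_weights={"task_a": 1.0, "task_b": 0.5},
    task_eval_steps={"task_a": 100})
assert pool.task_weight("task_b") == 0.5
assert pool.task_eval_steps("task_b") is None  # Unspecified entries default to None.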
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as core_lib
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask
def run_experiment_with_multitask_eval(
*,
distribution_strategy: tf.distribute.Strategy, train_task: base_task.Task,
eval_tasks: multitask.MultiTask, mode: str,
params: configs.MultiEvalExperimentConfig,
model_dir: str) -> tf.keras.Model:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A tf.distribute.Strategy instance.
train_task: A base_task.Task instance.
eval_tasks: A multitask.MultiTask with evaluation tasks.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: MultiEvalExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
Returns:
model: `tf.keras.Model` instance.
"""
is_training = 'train' in mode
is_eval = 'eval' in mode
with distribution_strategy.scope():
optimizer = train_task.create_optimizer(params.trainer, params.runtime)
model = train_task.build_model()
if is_training:
trainer = core_lib.Trainer(
config=params,
task=train_task,
model=model,
optimizer=optimizer,
train=True,
evaluate=False)
else:
trainer = None
if is_eval:
evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks,
model=model,
global_step=trainer.global_step if is_training else None)
else:
evaluator = None
if trainer:
checkpoint = trainer.checkpoint
global_step = trainer.global_step
else:
checkpoint = evaluator.checkpoint
global_step = evaluator.global_step
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize if trainer else None)
controller = orbit.Controller(
strategy=distribution_strategy,
trainer=trainer,
evaluator=evaluator,
global_step=global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train'),
eval_summary_dir=os.path.join(model_dir, 'validation'),
summary_interval=params.trainer.summary_interval)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if evaluator.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
return model
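An end-to-end wiring sketch for the driver above; `experiment_params` is a hypothetical, fully populated MultiEvalExperimentConfig, and the paths are placeholders.

from official.core import task_factory

experiment_params = ...  # A configs.MultiEvalExperimentConfig loaded elsewhere.
strategy = tf.distribute.MirroredStrategy()
train_task = task_factory.get_task(experiment_params.task, logging_dir='/tmp/exp')
eval_tasks = multitask.MultiTask.from_config(
    experiment_params.eval_tasks, logging_dir='/tmp/exp')
model = run_experiment_with_multitask_eval(
    distribution_strategy=strategy,
    train_task=train_task,
    eval_tasks=eval_tasks,
    mode='train_and_eval',
    params=experiment_params,
    model_dir='/tmp/exp')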