Commit 0bb46a95 authored by A. Unique TensorFlower
Browse files

Open source the second half of multi-task library

PiperOrigin-RevId: 365085378
parent 026a7880
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Orbit trainer that runs the joint train step of a `MultiTask`.

  The trainer owns the checkpoint, the global step, one distributed training
  dataset per task, and lazily created loss/metric objects for every task.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None):
    """Sets up state, the checkpoint and the per-task training datasets.

    Args:
      multi_task: the `MultiTask` holding the tasks to train.
      multi_task_model: the model passed to the joint train step; if it
        exposes `checkpoint_items`, they are added to the checkpoint.
      optimizer: the optimizer used by the joint train step.
      trainer_options: optional `orbit.StandardTrainerOptions`; a default
        instance is created when falsy.
    """
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer
    # Loss and metric containers are built on first property access.
    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    extra_checkpoint_items = (
        self.multi_task_model.checkpoint_items
        if hasattr(self.multi_task_model, "checkpoint_items") else {})
    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **extra_checkpoint_items)

    # One distributed dataset per task, keyed by task name.
    per_task_datasets = {
        task_name: orbit.utils.make_distributed_dataset(
            self.strategy, task.build_inputs, task.task_config.train_data)
        for task_name, task in self.multi_task.tasks.items()
    }

    super().__init__(
        train_dataset=per_task_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Resets every loss and metric accumulator before a training loop."""
    for loss_metric in self.training_losses.values():
      loss_metric.reset_states()
    for task_metrics in self.training_metrics.values():
      for metric in task_metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Collects per-task loss/metric values and the current learning rate.

    Returns:
      A dict mapping each key of `training_losses` to a dict of
      `{metric_name: value}`, plus a top-level "learning_rate" entry.
    """
    logs = {
        task_name: {loss_metric.name: loss_metric.result()}
        for task_name, loss_metric in self.training_losses.items()
    }
    for task_name, task_metrics in self.training_metrics.items():
      for metric in task_metrics:
        logs[task_name][metric.name] = metric.result()

    # Note that, the learning rate schedule is managed by the keras optimizer
    # internally, which respects the number of backward pass as `iterations`.
    # The learning rate schedule does not follow the trainer logical global
    # step of multiple tasks.
    learning_rate = self.optimizer.learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate(self.optimizer.iterations)
    logs["learning_rate"] = learning_rate
    return logs

  @property
  def checkpoint(self):
    """The `tf.train.Checkpoint` tracking model, optimizer and step."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Dict of `Mean` loss metrics: "total_loss" plus one per task name."""
    if self._training_losses is None:
      # "total_loss" is the summed training loss of the joint training; the
      # remaining entries hold the per-task losses.
      loss_metrics = {
          "total_loss":
              tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
      }
      for task_name in self.multi_task.tasks:
        loss_metrics[task_name] = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
      self._training_losses = loss_metrics
    return self._training_losses

  @property
  def training_metrics(self):
    """Dict mapping each task name to the metrics built by that task."""
    if self._training_metrics is None:
      metrics_by_task = {}
      for task_name, task in self.multi_task.tasks.items():
        metrics_by_task[task_name] = task.build_metrics(training=True)
      self._training_metrics = metrics_by_task
    return self._training_metrics

  @property
  def strategy(self):
    """The distribution strategy captured at construction time."""
    return self._strategy

  @property
  def multi_task(self):
    """The `MultiTask` being trained."""
    return self._multi_task

  @property
  def multi_task_model(self):
    """The model handed to the train step."""
    return self._multi_task_model

  @property
  def optimizer(self):
    """The optimizer handed to the train step."""
    return self._optimizer

  @property
  def global_step(self):
    """The trainer's global step variable."""
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def _joint_step(batch_map):
      task_losses = self.multi_task.joint_train_step(
          batch_map,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for loss_key, loss_value in task_losses.items():
        self.training_losses[loss_key].update_state(loss_value)

    # Pull the next batch from every task iterator and run one joint step.
    next_batches = tf.nest.map_structure(next, iterator_map)
    self.strategy.run(_joint_step, args=(next_batches,))
    self.global_step.assign_add(1)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations for default, TPU and GPU strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `base_trainer.MultiTaskBaseTrainer`."""

  def _assert_task_logs(self, results):
    """Checks that both mock tasks report a loss and their accuracy metric."""
    for task_name in ("bar", "foo"):
      self.assertContainsSubset(["training_loss", task_name + "_acc"],
                                results[task_name].keys())

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      foo_task = test_utils.MockFooTask(
          params=test_utils.FooConfig(), name="foo")
      bar_task = test_utils.MockBarTask(
          params=test_utils.BarConfig(), name="bar")
      test_multitask = multitask.MultiTask(
          tasks=[foo_task, bar_task], task_weights={"foo": 1.0, "bar": 1.0})
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer)
      num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
      results = test_trainer.train(num_steps)
      self._assert_task_logs(results)

  def test_trainer_with_configs(self):
    foo_routine = configs.TaskRoutine(
        task_name="foo", task_config=test_utils.FooConfig(), task_weight=0.5)
    bar_routine = configs.TaskRoutine(
        task_name="bar", task_config=test_utils.BarConfig(), task_weight=0.5)
    config = configs.MultiTaskConfig(task_routines=(foo_routine, bar_routine))
    test_multitask = multitask.MultiTask.from_config(config)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer)
    num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
    results = test_trainer.train(num_steps)
    self._assert_task_logs(results)
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)
# Run the tests when this module is executed as a script.
if __name__ == "__main__":
  tf.test.main()
...@@ -36,6 +36,39 @@ class MultiTaskConfig(hyperparams.Config): ...@@ -36,6 +36,39 @@ class MultiTaskConfig(hyperparams.Config):
task_routines: Tuple[TaskRoutine, ...] = () task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Options for the proportional task sampler."""
  # Exponent option for the sampler; defaults to 1.0.
  alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Options for the annealing task sampler."""
  # Number of training steps that make up one epoch.
  steps_per_epoch: int = 5
  # Total number of training steps.
  total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """One-of config selecting how tasks are sampled.

  `type` names which of the sub-configs below is active: "uniform",
  "proportional" or "annealing".
  """
  type: str = ""
  # Sub-configs are built per instance through `default_factory` so that no
  # single default object is shared (and mutated) across all config instances;
  # dataclasses in newer Python versions also reject unhashable instance
  # defaults.
  uniform: hyperparams.Config = dataclasses.field(
      default_factory=hyperparams.Config)
  proportional: ProportionalSampleConfig = dataclasses.field(
      default_factory=ProportionalSampleConfig)
  annealing: AnnealingSampleConfig = dataclasses.field(
      default_factory=AnnealingSampleConfig)
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer config extended with multi-task options."""
  # Name of the trainer implementation to use; defaults to "interleaving".
  trainer_type: str = "interleaving"
  # Built per instance so that configs do not share one mutable default
  # sampler config object.
  task_sampler: TaskSamplingConfig = dataclasses.field(
      default_factory=lambda: TaskSamplingConfig(type="proportional"))
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation."""
  # Each sub-config is created per instance through `default_factory` rather
  # than sharing one default object between all experiment configs.
  task: MultiTaskConfig = dataclasses.field(default_factory=MultiTaskConfig)
  trainer: MultiTaskTrainerConfig = dataclasses.field(
      default_factory=MultiTaskTrainerConfig)
  runtime: cfg.RuntimeConfig = dataclasses.field(
      default_factory=cfg.RuntimeConfig)
@dataclasses.dataclass @dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig): class MultiEvalExperimentConfig(cfg.ExperimentConfig):
"""An experiment config for single-task training and multi-task evaluation. """An experiment config for single-task training and multi-task evaluation.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update.

  Each train step samples exactly one task from the cumulative distribution
  returned by the task sampler and runs that task's own train step.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    """Builds the per-task step functions and per-task step counters.

    Args:
      multi_task: the `MultiTask` holding the tasks to train.
      multi_task_model: either a `MultiTaskBaseModel`, whose `sub_tasks` entry
        for the sampled task is trained, or a single model used by all tasks.
      optimizer: the optimizer passed to every task train step.
      task_sampler: provides the cumulative task distribution for a step.
      trainer_options: optional trainer options forwarded to the base class.
    """
    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
    # on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    """Returns the step counter variable of the task called `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task and runs a single train step for it.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """
    # Sample one task to train according to a multinomial distribution. The
    # random number is seeded with the global step, so it only changes when
    # the global step advances.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = (
        self._task_sampler.task_cumulative_distribution(self.global_step))
    # Prepend a [0.0] for indexing convenience: task i owns the bucket
    # [bounds[i], bounds[i + 1]).
    bounds = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)

    task_names = list(self.multi_task.tasks)
    last_idx = len(task_names) - 1
    for idx, name in enumerate(task_names):
      selected = rn >= bounds[idx]
      # The last bucket has no upper bound. Floating-point rounding can leave
      # the final cumulative value slightly below 1.0; with a strict upper
      # bound such a draw would select no task, the global step would not
      # advance, and the same draw would then repeat on every later step.
      if idx < last_idx:
        selected = tf.logical_and(selected, rn < bounds[idx + 1])
      if selected:
        self.strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations for default, TPU and GPU strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `interleaving_trainer.MultiTaskInterleavingTrainer`."""

  def _assert_task_logs(self, results):
    """Checks that both mock tasks report a loss and their accuracy metric."""
    for task_name in ("bar", "foo"):
      self.assertContainsSubset(["training_loss", task_name + "_acc"],
                                results[task_name].keys())

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    with distribution.scope():
      foo_task = test_utils.MockFooTask(
          params=test_utils.FooConfig(), name="foo")
      bar_task = test_utils.MockBarTask(
          params=test_utils.BarConfig(), name="bar")
      test_multitask = multitask.MultiTask(tasks=[foo_task, bar_task])
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
      results = test_trainer.train(num_steps)
      self._assert_task_logs(results)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    foo_routine = configs.TaskRoutine(
        task_name="foo", task_config=test_utils.FooConfig(), task_weight=3.0)
    bar_routine = configs.TaskRoutine(
        task_name="bar", task_config=test_utils.BarConfig(), task_weight=1.0)
    config = configs.MultiTaskConfig(task_routines=(foo_routine, bar_routine))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          steps_per_epoch=num_step/5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(
          tf.convert_to_tensor(num_step, dtype=tf.int32))
      self._assert_task_logs(results)
      self.assertEqual(test_trainer.global_step.numpy(), num_step)
      # Every global step must have been attributed to exactly one task.
      bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
      foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
      self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
# Run the tests when this module is executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import abc
from typing import Union, Dict, Text
import tensorflow as tf
from official.modeling.multitask import configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """Abstract interface for the task sampling used by the interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Stores the mapping from task name to task weight."""
    self._task_weights = task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Computes the cumulative distribution used to sample tasks.

    The result is the cumulative form of the multinomial task distribution
    that tasks are sampled against.

    Args:
      global_step: A tensor indicating the current progress of training.

    Returns:
      A float tensor representing the cumulative sampling distribution over
      the tasks.
    """
class UniformTaskSampler(TaskSampler):
  """Sampler that gives every task the same probability."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Precomputes the cumulative distribution; weight values are ignored."""
    super().__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    equal_shares = tf.constant([1.0 / num_tasks] * num_tasks, dtype=tf.float32)
    self._uniform_cumulative = tf.math.cumsum(equal_shares)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed distribution; `global_step` is unused."""
    del global_step
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Sampler whose task probabilities are proportional to `weight ** alpha`."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Precomputes the cumulative distribution from the task weights.

    Args:
      task_weights: mapping from task name to task weight.
      alpha: exponent applied to every weight before normalization.
    """
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    # Weights are taken in the iteration order of the dict.
    ordered_weights = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    powered_weights = tf.math.pow(ordered_weights, self._alpha)
    probabilities = powered_weights / tf.reduce_sum(powered_weights)
    self._porportional_cumulative = tf.math.cumsum(probabilities)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed distribution; `global_step` is unused."""
    del global_step
    return self._porportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Sampler whose weight exponent changes with the training epoch."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    """Stores the epoch length and derives the total number of epochs.

    Args:
      task_weights: mapping from task name to task weight.
      steps_per_epoch: number of global steps that make up one epoch.
      total_steps: total number of training steps.
    """
    super().__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    epochs = total_steps / self._steps_per_epoch
    self._total_epochs = tf.cast(epochs, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the cumulative distribution for the epoch of `global_step`."""
    step = tf.cast(global_step, dtype=tf.float32)
    cur_epoch = tf.math.floor(step / self._steps_per_epoch)
    # The exponent is a linear function of the epoch index; the 1e-10 term
    # guards the denominator.
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    # Weights are taken in the iteration order of the dict.
    ordered_weights = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    powered_weights = tf.math.pow(
        ordered_weights, tf.cast(alpha, dtype=tf.float32))
    probabilities = powered_weights / tf.reduce_sum(powered_weights)
    return tf.math.cumsum(probabilities)
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Builds the task sampler selected by `config.type`.

  Args:
    config: sampling config; `config.type` picks the sampler and
      `config.get()` supplies its options.
    task_weights: mapping from task name to task weight.

  Returns:
    A `UniformTaskSampler`, `ProportionalTaskSampler` or
    `AnnealingTaskSampler`.

  Raises:
    RuntimeError: if `config.type` is none of the supported sampler types.
  """
  selected_options = config.get()
  sampler_type = config.type
  if sampler_type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  if sampler_type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=selected_options.alpha)
  if sampler_type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=selected_options.steps_per_epoch,
        total_steps=selected_options.total_steps)
  raise RuntimeError('Task sampler type not supported')
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import tensorflow as tf
from official.modeling.multitask import configs
from official.modeling.multitask import task_sampler as sampler
class TaskSamplerTest(tf.test.TestCase):
  """Tests the cumulative distributions produced by each task sampler."""

  def setUp(self):
    super().setUp()
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    sampling_config = configs.TaskSamplingConfig(type='uniform')
    uniform_sampler = sampler.get_task_sampler(sampling_config,
                                               self._task_weights)
    expected = [0.333333, 0.666666, 1.0]
    for step in range(5):
      step_tensor = tf.constant(step, dtype=tf.int64)
      cumulative = uniform_sampler.task_cumulative_distribution(step_tensor)
      self.assertAllClose(expected, cumulative.numpy())

  def test_proportional_sample_distribution(self):
    sampling_config = configs.TaskSamplingConfig(
        type='proportional',
        proportional=configs.ProportionalSampleConfig(alpha=2.0))
    prop_sampler = sampler.get_task_sampler(sampling_config,
                                            self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    expected = [0.07142857, 0.35714286, 1.0]
    for step in range(5):
      step_tensor = tf.constant(step, dtype=tf.int64)
      cumulative = prop_sampler.task_cumulative_distribution(step_tensor)
      self.assertAllClose(expected, cumulative.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    sampling_config = configs.TaskSamplingConfig(
        type='annealing',
        annealing=configs.AnnealingSampleConfig(
            steps_per_epoch=step_per_epoch,
            total_steps=step_per_epoch * num_epoch))
    anneal_sampler = sampler.get_task_sampler(sampling_config,
                                              self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    # One expected cumulative distribution per epoch.
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      expected = expected_cumulative_epochs[epoch]
      for _ in range(step_per_epoch):
        cumulative = anneal_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected, cumulative.numpy())
# Run the tests when this module is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utils for mock models and tasks."""
from typing import Dict, Text
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.multitask import base_model
class MockFooModel(tf.keras.Model):
  """Mock model that accepts inputs keyed by either 'foo' or 'bar'."""

  def __init__(self, shared_layer, *args, **kwargs):
    """Keeps `shared_layer` and creates the foo-specific dense layer."""
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._foo_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    """Adds a zero loss and runs the shared then the foo-specific layer."""
    self.add_loss(tf.zeros((1,), dtype=tf.float32))
    # Prefer the 'foo' feature and fall back to 'bar' when it is absent.
    feature_key = "foo" if "foo" in inputs else "bar"
    shared_output = self._share_layer(inputs[feature_key])
    return self._foo_specific_layer(shared_output)
class MockBarModel(tf.keras.Model):
  """Mock model that consumes inputs keyed by 'bar'."""

  def __init__(self, shared_layer, *args, **kwargs):
    """Keeps `shared_layer` and creates the bar-specific dense layer."""
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._bar_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    """Adds a zero loss and runs the shared then the bar-specific layer."""
    self.add_loss(tf.zeros((2,), dtype=tf.float32))
    shared_output = self._share_layer(inputs["bar"])
    return self._bar_specific_layer(shared_output)
class MockMultiTaskModel(base_model.MultiTaskBaseModel):
  """Mock multi-task model whose 'foo' and 'bar' models share a dense layer."""

  def __init__(self, *args, **kwargs):
    # The shared layer is created before the base initializer runs so that it
    # already exists when the sub-task models are instantiated.
    self._shared_dense = tf.keras.layers.Dense(1)
    super().__init__(*args, **kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Returns one mock model per task name, both using the shared layer."""
    foo_model = MockFooModel(self._shared_dense)
    bar_model = MockBarModel(self._shared_dense)
    return {"foo": foo_model, "bar": bar_model}
def mock_data(feature_name):
  """Builds an endless mock dataset whose features use key `feature_name`.

  Args:
    feature_name: key under which the feature tensor is stored.

  Returns:
    A repeated `tf.data.Dataset` of `({feature_name: zeros}, zero label)`
    examples, prefetched and batched by 2.
  """

  def _to_example(_):
    features = {feature_name: tf.zeros(shape=(2,), dtype=tf.float32)}
    label = tf.zeros([1], dtype=tf.int32)
    return features, label

  examples = tf.data.Dataset.range(1).repeat().map(
      _to_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return examples.prefetch(buffer_size=1).batch(2, drop_remainder=True)
class FooConfig(cfg.TaskConfig):
  """Task config with no extra fields, used to register the mock foo task."""
class BarConfig(cfg.TaskConfig):
  """Task config with no extra fields, used to register the mock bar task."""
@task_factory.register_task_cls(FooConfig)
class MockFooTask(base_task.Task):
  """Mock task registered for `FooConfig`, for use in tests."""

  def build_metrics(self, training: bool = True):
    """Returns a single accuracy metric named "foo_acc"."""
    del training
    return [tf.keras.metrics.Accuracy(name="foo_acc")]

  def build_inputs(self, params):
    """Returns the mock dataset keyed by "foo"; `params` is not used."""
    return mock_data("foo")

  def build_model(self) -> tf.keras.Model:
    """Returns a `MockFooModel` with its own dense layer as shared layer."""
    return MockFooModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Returns the mean of the squared error plus any auxiliary losses."""
    total = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      total = total + tf.add_n(aux_losses)
    return tf.reduce_mean(total)
@task_factory.register_task_cls(BarConfig)
class MockBarTask(base_task.Task):
  """Mock task registered for `BarConfig`, for use in tests."""

  def build_metrics(self, training: bool = True):
    """Returns a single accuracy metric named "bar_acc"."""
    del training
    return [tf.keras.metrics.Accuracy(name="bar_acc")]

  def build_inputs(self, params):
    """Returns the mock dataset keyed by "bar"; `params` is not used."""
    return mock_data("bar")

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Returns the mean of the squared error plus any auxiliary losses."""
    total = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      total = total + tf.add_n(aux_losses)
    return tf.reduce_mean(total)
...@@ -21,9 +21,113 @@ import tensorflow as tf ...@@ -21,9 +21,113 @@ import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import base_trainer as core_lib from official.core import base_trainer as core_lib
from official.core import train_utils from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
# Maps a trainer type name to the trainer class that implements it.
TRAINERS = {
    'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
    'joint': base_trainer.MultiTaskBaseTrainer
}
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution distribution_strategy.
    task: A MultiTaskTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
      model: `base_model.MultiTaskBaseModel` instance.

  Raises:
    NotImplementedError: if `mode` is not one of the supported modes.
  """
  # Validate the mode before building anything. Otherwise an unknown mode
  # containing neither 'train' nor 'eval' would crash on the missing
  # evaluator, and other unknown modes would only be rejected after the
  # checkpoint manager and controller had been created.
  if mode not in ('train', 'eval', 'train_and_eval', 'continuous_eval'):
    raise NotImplementedError('The mode is not implemented: %s' % mode)

  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    # Only the interleaving trainer takes a task sampler.
    if params.trainer.trainer_type == 'interleaving':
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=task,
          model=model,
          global_step=trainer.global_step if is_training else None)
    else:
      evaluator = None

  # The trainer, when present, owns the checkpoint and global step; eval-only
  # modes fall back to the evaluator's.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    else:  # mode == 'continuous_eval', guaranteed by the check above.

      def timeout_fn():
        # Stop waiting for checkpoints once the evaluated step has reached
        # the configured number of training steps.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)

  return model
def run_experiment_with_multitask_eval( def run_experiment_with_multitask_eval(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
from official.modeling.multitask import train_lib
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the multitask `train_lib` entry points."""

  def setUp(self):
    super().setUp()
    optimizer_overrides = {
        'optimizer': {
            'type': 'sgd',
        },
        'learning_rate': {
            'type': 'constant'
        }
    }
    trainer_overrides = {
        'checkpoint_interval': 10,
        'steps_per_loop': 10,
        'summary_interval': 10,
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
        'continuous_eval_timeout': 1,
        'optimizer_config': optimizer_overrides,
    }
    # Overrides applied on top of each experiment config used by the tests.
    self._test_config = {'trainer': trainer_overrides}

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    task_routines = (
        configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig()),
        configs.TaskRoutine(
            task_name='bar', task_config=test_utils.BarConfig()),
    )
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(task_routines=task_routines))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    eval_routines = (
        configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig()),
        configs.TaskRoutine(
            task_name='bar', task_config=test_utils.BarConfig()),
    )
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=configs.MultiTaskConfig(task_routines=eval_routines))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)
# Run the tests when this module is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment