Commit eb5d69ae authored by Hongkun Yu, committed by A. Unique TensorFlower

Support new Keras optimizers (after 2.10) for the multitask trainer.

Users must create all model variables before training.

PiperOrigin-RevId: 476782148
parent 34d91a7c
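
For context, the constraint described above comes from the new Keras optimizer (tf.keras.optimizers.experimental.Optimizer, made the default after TF 2.10): it creates its slot variables all together in an explicit build(var_list) call, so every model variable has to exist before that call. A minimal sketch of the ordering, using a throwaway Dense model as a stand-in (illustrative only, not code from this commit; assumes TF >= 2.10):

import tensorflow as tf

# All model variables must exist first; a model with a known input shape
# creates its weights at construction time.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])

# The new optimizer then creates its slot variables all together up front,
# so no variables are created later inside the training loop.
optimizer = tf.keras.optimizers.experimental.SGD(momentum=0.9)
optimizer.build(model.trainable_variables)

The trainer change below follows the same order: multi_task_model.build() first, then optimizer.build(...).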
...@@ -43,3 +43,12 @@ class MultiTaskBaseModel(tf.Module):
   def initialize(self):
     """Optional function that loads a pre-train checkpoint."""
     return
+
+  def build(self):
+    """Builds the networks for tasks to make sure variables are created."""
+    # Try to build all sub tasks.
+    for task_model in self._sub_tasks.values():
+      # Assumes all the tf.Module models are built because we don't have any
+      # way to check them.
+      if isinstance(task_model, tf.keras.Model) and not task_model.built:
+        _ = task_model(task_model.inputs)
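
A note on the pattern used in build() above: a Keras sub-task model that exposes symbolic inputs can be built by calling it once on those tf.keras.Input tensors, which forces creation of all of its weights without feeding real data. A minimal, self-contained sketch of that behavior (the toy TinyTaskModel below is illustrative, not part of this repository):

import tensorflow as tf

class TinyTaskModel(tf.keras.Model):
  """Toy stand-in for a sub-task model; weights are created lazily on first call."""

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(1)
    # Symbolic inputs a trainer can use to build the model without data,
    # mirroring the `inputs` attribute the mock models in this change define.
    self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32)}

  def call(self, inputs):
    return self._dense(inputs["foo"])

model = TinyTaskModel()
assert not model.built
_ = model(model.inputs)  # one symbolic call; all weights now exist
assert model.trainable_variables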
...@@ -31,7 +31,9 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
                multi_task: multitask.MultiTask,
                multi_task_model: Union[tf.keras.Model,
                                        base_model.MultiTaskBaseModel],
-               optimizer: tf.optimizers.Optimizer,
+               optimizer: Union[tf.optimizers.Optimizer,
+                                tf.keras.optimizers.experimental.Optimizer,
+                                tf.keras.optimizers.legacy.Optimizer],
                task_sampler: sampler.TaskSampler,
                trainer_options=None):
     super().__init__(
...@@ -69,6 +71,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
         name: orbit.utils.create_global_step() for name in self.multi_task.tasks
     }
 
+    # If the new Keras optimizer is used, we require that all model variables
+    # are created before training and let the optimizer create the slot
+    # variables all together.
+    if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
+      multi_task_model.build()
+      optimizer.build(multi_task_model.trainable_variables)
+
   def task_step_counter(self, name):
     return self._task_step_counters[name]
......
...@@ -28,6 +28,8 @@ class MockFooModel(tf.keras.Model):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._foo_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32),
+                   "bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
 
   def call(self, inputs):
     self.add_loss(tf.zeros((1,), dtype=tf.float32))
...@@ -39,11 +41,13 @@ class MockFooModel(tf.keras.Model):
 
 class MockBarModel(tf.keras.Model):
+  """A mock model that can only consume 'bar' inputs."""
 
   def __init__(self, shared_layer, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._bar_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
 
   def call(self, inputs):
     self.add_loss(tf.zeros((2,), dtype=tf.float32))
......
...@@ -98,7 +98,6 @@ def run_experiment(
     checkpoint = evaluator.checkpoint
     global_step = evaluator.global_step
-    # TODO(hongkuny,haozhangthu): Revisit initialization method.
     checkpoint_manager = tf.train.CheckpointManager(
         checkpoint,
         directory=model_dir,
......
...@@ -58,8 +58,9 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
           strategy_combinations.one_device_strategy_gpu,
       ],
       mode='eager',
+      optimizer=['sgd_experimental', 'sgd'],
       flag_mode=['train', 'eval', 'train_and_eval']))
-  def test_end_to_end(self, distribution_strategy, flag_mode):
+  def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiTaskExperimentConfig(
         task=configs.MultiTaskConfig(
...@@ -70,6 +71,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
                 task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
+    experiment_config.trainer.optimizer_config.optimizer.type = optimizer
     with distribution_strategy.scope():
       test_multitask = multitask.MultiTask.from_config(experiment_config.task)
       model = test_utils.MockMultiTaskModel()
......