Commit eb5d69ae authored by Hongkun Yu, committed by A. Unique TensorFlower

Support new Keras optimizers (after 2.10) for the multitask trainer.

Users must create all model variables before training.

PiperOrigin-RevId: 476782148
parent 34d91a7c
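Background: starting with Keras 2.10, the new optimizer class (`tf.keras.optimizers.experimental.Optimizer`) creates all of its slot variables in a single `build(var_list)` call instead of lazily per variable, which is why every model variable has to exist before training starts. A minimal sketch of that contract, using a stand-in `tf.keras.Sequential` model rather than this repo's multitask classes:

```python
import tensorflow as tf

# Stand-in model; input_shape forces its variables to be created eagerly.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
assert model.built

# The new Keras optimizer creates every slot variable here, in one shot.
optimizer = tf.keras.optimizers.experimental.Adam()
optimizer.build(model.trainable_variables)
```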
@@ -43,3 +43,12 @@ class MultiTaskBaseModel(tf.Module):
   def initialize(self):
     """Optional function that loads a pre-train checkpoint."""
     return
+
+  def build(self):
+    """Builds the networks for tasks to make sure variables are created."""
+    # Try to build all sub tasks.
+    for task_model in self._sub_tasks.values():
+      # Assumes all the tf.Module models are built because we don't have any
+      # way to check them.
+      if isinstance(task_model, tf.keras.Model) and not task_model.built:
+        _ = task_model(task_model.inputs)
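The `build()` above leans on each sub-task model carrying symbolic `inputs` it can be called on. A rough self-contained illustration of that trick (the `ToySubTask` class is hypothetical, not part of this change):

```python
import tensorflow as tf

class ToySubTask(tf.keras.Model):
  """Hypothetical sub-task model that records its symbolic inputs."""

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(1)
    # Stashing a symbolic placeholder lets build() call the model
    # without real data.
    self.inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)

  def call(self, inputs):
    return self._dense(inputs)

task_model = ToySubTask()
assert not task_model.built
_ = task_model(task_model.inputs)  # The same call build() makes above.
assert task_model.built
```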
@@ -31,7 +31,9 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
                multi_task: multitask.MultiTask,
                multi_task_model: Union[tf.keras.Model,
                                        base_model.MultiTaskBaseModel],
-               optimizer: tf.optimizers.Optimizer,
+               optimizer: Union[tf.optimizers.Optimizer,
+                                tf.keras.optimizers.experimental.Optimizer,
+                                tf.keras.optimizers.legacy.Optimizer],
                task_sampler: sampler.TaskSampler,
                trainer_options=None):
     super().__init__(
@@ -69,6 +71,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
         name: orbit.utils.create_global_step() for name in self.multi_task.tasks
     }
+    # If the new Keras optimizer is used, we require all model variables to be
+    # created before training and let the optimizer create the slot variables
+    # all together.
+    if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
+      multi_task_model.build()
+      optimizer.build(multi_task_model.trainable_variables)

   def task_step_counter(self, name):
     return self._task_step_counters[name]
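For contrast, the legacy optimizer (`tf.keras.optimizers.legacy.Optimizer`) still creates its slots lazily inside `apply_gradients`, so only the new optimizer needs this eager step. A hedged sketch of the dispatch, with `maybe_build_optimizer` as a hypothetical free-standing helper:

```python
import tensorflow as tf

def maybe_build_optimizer(optimizer, multi_task_model):
  # Only the new (post-2.10) Keras optimizer requires all variables to
  # exist up front; assumes multi_task_model has a no-argument build(),
  # as MultiTaskBaseModel does above.
  if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
    multi_task_model.build()
    optimizer.build(multi_task_model.trainable_variables)
  # Legacy optimizers create slot variables lazily in apply_gradients(),
  # so nothing needs to happen here.
```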
@@ -28,6 +28,8 @@ class MockFooModel(tf.keras.Model):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._foo_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32),
+                   "bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}

   def call(self, inputs):
     self.add_loss(tf.zeros((1,), dtype=tf.float32))
@@ -39,11 +41,13 @@ class MockFooModel(tf.keras.Model):

 class MockBarModel(tf.keras.Model):
   """A mock model that can only consume 'bar' inputs."""

   def __init__(self, shared_layer, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._share_layer = shared_layer
     self._bar_specific_layer = tf.keras.layers.Dense(1)
+    self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}

   def call(self, inputs):
     self.add_loss(tf.zeros((2,), dtype=tf.float32))
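The mocks assign `self.inputs` themselves because a plain subclassed `tf.keras.Model` has no symbolic inputs until its first call, which would leave the generic `build()` above nothing to call the model on. A small demonstration of that default behavior (`NoInputsModel` is hypothetical, not from this change):

```python
import tensorflow as tf

class NoInputsModel(tf.keras.Model):
  """Hypothetical subclassed model with no recorded symbolic inputs."""

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self._dense(inputs)

model = NoInputsModel()
print(model.built)  # False: variables don't exist yet.
_ = model(tf.zeros((1, 2)))  # First call creates them.
print(model.built)  # True
```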
@@ -98,7 +98,6 @@ def run_experiment(
   checkpoint = evaluator.checkpoint
   global_step = evaluator.global_step
-  # TODO(hongkuny,haozhangthu): Revisit initialization method.
   checkpoint_manager = tf.train.CheckpointManager(
       checkpoint,
       directory=model_dir,
@@ -58,8 +58,9 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
           strategy_combinations.one_device_strategy_gpu,
       ],
       mode='eager',
+      optimizer=['sgd_experimental', 'sgd'],
       flag_mode=['train', 'eval', 'train_and_eval']))
-  def test_end_to_end(self, distribution_strategy, flag_mode):
+  def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiTaskExperimentConfig(
         task=configs.MultiTaskConfig(
@@ -70,6 +71,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
                 task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
+    experiment_config.trainer.optimizer_config.optimizer.type = optimizer
     with distribution_strategy.scope():
       test_multitask = multitask.MultiTask.from_config(experiment_config.task)
       model = test_utils.MockMultiTaskModel()
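The two parameter values exercise both optimizer stacks end to end. Outside this repo's config machinery, the distinction looks roughly like this (assuming 'sgd' and 'sgd_experimental' map to the legacy and new SGD classes, mirroring the test parameters above):

```python
import tensorflow as tf

legacy_sgd = tf.keras.optimizers.legacy.SGD(0.1, momentum=0.9)        # 'sgd'
new_sgd = tf.keras.optimizers.experimental.SGD(0.1, momentum=0.9)     # 'sgd_experimental'

var = tf.Variable(1.0)
with tf.GradientTape() as tape:
  loss = var ** 2
grad = tape.gradient(loss, var)

# Both expose the same apply_gradients API; they differ in that the new
# optimizer can (and in graph contexts must) create its momentum slots
# eagerly via an explicit build() before training.
legacy_sgd.apply_gradients([(grad, var)])
new_sgd.apply_gradients([(grad, var)])
```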