"...resnet50_tensorflow.git" did not exist on "fd7b6887fb294e90356d5664724083d1f61671ef"
Commit c60951b1 authored by Chen Chen, committed by A. Unique TensorFlower

Fix a bug where the transformer model cannot be saved when `save_weights_only` is set to False: https://github.com/tensorflow/models/issues/9186

PiperOrigin-RevId: 330292460
parent 4a577082
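
For context on the failure: with `save_weights_only=False`, `tf.keras.callbacks.ModelCheckpoint` calls `model.save()` rather than `model.save_weights()`, which also serializes the optimizer and its learning-rate schedule. The diff below suggests the root cause was the schedule storing `warmup_steps` as a `tf.Tensor` (the result of `tf.cast`), which cannot be serialized; the fix keeps the plain Python value for `get_config()` and uses a separate casted tensor for the arithmetic.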
```diff
@@ -209,6 +209,14 @@ def define_transformer_flags():
       help=flags_core.help_wrap(
           'Whether to do checkpointing during training. When running under '
           'benchmark harness, we will avoid checkpointing.'))
+  flags.DEFINE_bool(
+      name='save_weights_only',
+      default=True,
+      help=flags_core.help_wrap(
+          'Only used when above `enable_checkpointing` is True. '
+          'If True, then only the model\'s weights will be saved '
+          '(`model.save_weights(filepath)`), else the full model is saved '
+          '(`model.save(filepath)`)'))
 
   flags_core.set_defaults(
       data_dir='/tmp/translate_ende',
```
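For a sense of how this boolean flag behaves at the command line, here is a minimal self-contained sketch; only the flag name and default are taken from the diff above, everything else is illustrative scaffolding:

```python
from absl import app
from absl import flags

# Mirrors the flag added above; the help text is abbreviated here.
flags.DEFINE_bool(
    name='save_weights_only',
    default=True,
    help='If True, save only the weights; else save the full model.')

FLAGS = flags.FLAGS


def main(_):
  # absl boolean flags accept --save_weights_only, --nosave_weights_only,
  # and --save_weights_only=false on the command line.
  print('save_weights_only =', FLAGS.save_weights_only)


if __name__ == '__main__':
  app.run(main)
```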
```diff
@@ -35,7 +35,8 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
     super(LearningRateSchedule, self).__init__()
     self.initial_learning_rate = initial_learning_rate
     self.hidden_size = hidden_size
-    self.warmup_steps = tf.cast(warmup_steps, tf.float32)
+    self.warmup_steps = warmup_steps
+    self.warmup_steps_tensor = tf.cast(warmup_steps, tf.float32)
 
   def __call__(self, global_step):
     """Calculate learning rate with linear warmup and rsqrt decay.
@@ -52,9 +53,10 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
     learning_rate = self.initial_learning_rate
     learning_rate *= (self.hidden_size**-0.5)
     # Apply linear warmup
-    learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
+    learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps_tensor)
     # Apply rsqrt decay
-    learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
+    learning_rate /= tf.sqrt(
+        tf.maximum(global_step, self.warmup_steps_tensor))
     return learning_rate
 
   def get_config(self):
```
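This hunk is the heart of the fix: when the full model is saved, Keras serializes a custom `LearningRateSchedule` through its `get_config()`, and a `tf.Tensor` stored on the object breaks that round trip. Below is a minimal sketch of the corrected pattern; the class name and standalone structure are assumptions for illustration, not the full upstream class:

```python
import tensorflow as tf


class WarmupRsqrtSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Sketch of the pattern from the diff: keep config values as plain
  Python numbers and cast a separate tensor for the math."""

  def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
    super().__init__()
    self.initial_learning_rate = initial_learning_rate
    self.hidden_size = hidden_size
    self.warmup_steps = warmup_steps  # plain int: safe to serialize
    self.warmup_steps_tensor = tf.cast(warmup_steps, tf.float32)

  def __call__(self, global_step):
    global_step = tf.cast(global_step, tf.float32)
    lr = self.initial_learning_rate * self.hidden_size**-0.5
    # Linear warmup, then rsqrt decay, as in the upstream schedule.
    lr *= tf.minimum(1.0, global_step / self.warmup_steps_tensor)
    return lr / tf.sqrt(tf.maximum(global_step, self.warmup_steps_tensor))

  def get_config(self):
    # Returning a tf.Tensor here is what presumably broke model.save();
    # plain Python values serialize cleanly.
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'hidden_size': self.hidden_size,
        'warmup_steps': self.warmup_steps,
    }
```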
```diff
@@ -415,7 +415,7 @@ class TransformerTask(object):
       ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
       callbacks.append(
           tf.keras.callbacks.ModelCheckpoint(
-              ckpt_full_path, save_weights_only=True))
+              ckpt_full_path, save_weights_only=params["save_weights_only"]))
     return callbacks
 
   def _load_weights_if_possible(self, model, init_weight_path=None):
```
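For reference, the two modes that the now-plumbed-through `save_weights_only` parameter selects between differ substantially. A minimal, self-contained sketch (the toy model and paths are illustrative, not from the repository):

```python
import numpy as np
import tensorflow as tf

# Toy model just to exercise the callback; not the transformer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')

# save_weights_only=True -> model.save_weights(): variables only.
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    'ckpt/cp-{epoch:04d}.ckpt', save_weights_only=True)

# save_weights_only=False -> model.save(): architecture, variables, and
# optimizer state, which is why every attached object (e.g. a custom LR
# schedule) must be serializable.
full_cb = tf.keras.callbacks.ModelCheckpoint(
    'full_model/cp-{epoch:04d}', save_weights_only=False)

model.fit(x, y, epochs=2, callbacks=[weights_cb, full_cb], verbose=0)
```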
```diff
@@ -86,6 +86,13 @@ class TransformerTaskTest(tf.test.TestCase):
     t = transformer_main.TransformerTask(FLAGS)
     t.train()
 
+  def test_train_save_full_model(self):
+    if context.num_gpus() >= 2:
+      self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+    FLAGS.save_weights_only = False
+    t = transformer_main.TransformerTask(FLAGS)
+    t.train()
+
   def test_train_static_batch(self):
     if context.num_gpus() >= 2:
       self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
```