Commit 048f5a95 authored by Jing Li, committed by A. Unique TensorFlower

Export the pretrained transformer-xl model for the finetuning stage.

Add an option to initialize from a transformer-xl model checkpoint.

PiperOrigin-RevId: 274875006
parent 188536e7
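
Note: the export/restore in this commit relies on TF2 object-based checkpointing, where the keyword passed to tf.train.Checkpoint(...) becomes part of the checkpoint's dependency graph. A minimal sketch of the round trip, with toy Dense layers standing in for the real transformer-xl backbone (every name except the tf.train.Checkpoint API itself is a placeholder):

import tensorflow as tf

# Toy stand-in for model.transformerxl_model in the diff below.
pretrain_backbone = tf.keras.layers.Dense(4)
pretrain_backbone(tf.zeros([1, 4]))  # Build the variables.

# Producer side: save only the sub-model, keyed as "transformer_xl",
# mirroring tf.train.Checkpoint(transformer_xl=model.transformerxl_model).
save_path = tf.train.Checkpoint(transformer_xl=pretrain_backbone).save(
    "/tmp/transformer_xl.ckpt")

# Consumer side: the restore matches variables only under the same key,
# which is why the new init_from_transformerxl flag must select this layout.
finetune_backbone = tf.keras.layers.Dense(4)
finetune_backbone(tf.zeros([1, 4]))
status = tf.train.Checkpoint(
    transformer_xl=finetune_backbone).restore(save_path)
status.assert_existing_objects_matched()
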
@@ -38,6 +38,11 @@ flags.DEFINE_string(
     "init_checkpoint",
     default=None,
     help="Checkpoint path for initializing the model.")
+flags.DEFINE_bool(
+    "init_from_transformerxl",
+    default=False,
+    help="Init from a transformerxl model checkpoint. Otherwise, init from the "
+    "entire model checkpoint.")
 
 # Optimization config
 flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
@@ -179,6 +179,7 @@ def main(unused_argv):
       train_input_fn=train_input_fn,
       test_input_fn=test_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
+      init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import os
 
 from absl import app
 from absl import flags
@@ -103,7 +104,7 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
 
-  training_utils.train(
+  model = training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
       input_meta_data=input_meta_data,
@@ -112,6 +113,7 @@ def main(unused_argv):
       train_input_fn=train_input_fn,
       test_input_fn=None,
       init_checkpoint=FLAGS.init_checkpoint,
+      init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
@@ -120,6 +122,13 @@ def main(unused_argv):
       model_dir=FLAGS.model_dir,
       save_steps=FLAGS.save_steps)
 
+  # Export transformer-xl model checkpoint to be used in finetuning.
+  checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
+  saved_path = checkpoint.save(
+      os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
+  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
+               saved_path)
+
 
 if __name__ == "__main__":
   assert tf.version.VERSION.startswith('2.')
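
The checkpoint exported above lands under <model_dir>/pretrained/. As a hedged sketch of the consumer side (the directory value is an assumption; tf.train.latest_checkpoint reads the checkpoint state file that checkpoint.save() maintains):

import os
import tensorflow as tf

# Assumed pretraining output directory; only the "pretrained" sub-directory
# and the transformer_xl.ckpt prefix come from the diff above.
model_dir = "/tmp/xlnet_pretrain"
init_checkpoint = tf.train.latest_checkpoint(
    os.path.join(model_dir, "pretrained"))
# Yields e.g. .../pretrained/transformer_xl.ckpt-1, which would then be passed
# as --init_checkpoint together with --init_from_transformerxl=true.
print(init_checkpoint)
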
@@ -284,6 +284,7 @@ def main(unused_argv):
       train_input_fn=train_input_fn,
       test_input_fn=test_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
+      init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
@@ -77,6 +77,7 @@ def train(
     metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
     test_input_fn: Optional[Callable] = None,
     init_checkpoint: Optional[Text] = None,
+    init_from_transformerxl: Optional[bool] = False,
     model_dir: Optional[Text] = None,
     save_steps: Optional[int] = None,
     run_eagerly: Optional[bool] = False):
@@ -106,6 +107,8 @@ def train(
       is skipped.
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
+    init_from_transformerxl: Whether to load only the `transformerxl_model`
+      of the model returned by `model_fn`.
     model_dir: The directory of model (checkpoints, summaries).
     save_steps: The frequency to save checkpoints. Every save_steps, we save a
       model checkpoint.
@@ -151,6 +154,10 @@ def train(
     if init_checkpoint:
       logging.info("restore from %s", init_checkpoint)
-      checkpoint = tf.train.Checkpoint(model=model)
+      if init_from_transformerxl:
+        checkpoint = tf.train.Checkpoint(
+            transformer_xl=model.transformerxl_model)
+      else:
+        checkpoint = tf.train.Checkpoint(model=model)
       checkpoint.restore(init_checkpoint)
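
One property worth noting about this restore path: tf.train.Checkpoint supports deferred restoration, so checkpoint.restore(init_checkpoint) also works if some of the model's variables have not been created yet; values are applied as the variables are built. A minimal sketch (toy layer and path are placeholders):

import tensorflow as tf

# Save a toy "transformer_xl" object (placeholder for the real backbone).
backbone = tf.keras.layers.Dense(3)
backbone(tf.zeros([1, 3]))
path = tf.train.Checkpoint(transformer_xl=backbone).save("/tmp/txl_demo.ckpt")

# Restore into a layer whose variables do not exist yet: the assignment is
# deferred and applied when the first call creates the variables.
fresh = tf.keras.layers.Dense(3)
status = tf.train.Checkpoint(transformer_xl=fresh).restore(path)
fresh(tf.zeros([1, 3]))  # Variable creation triggers the deferred restore.
status.assert_existing_objects_matched()
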