# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""XLNet pretraining runner in tf2.0."""

import functools
import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
from official.common import distribute_utils
from official.legacy.xlnet import common_flags
from official.legacy.xlnet import data_utils
from official.legacy.xlnet import optimization
from official.legacy.xlnet import training_utils
from official.legacy.xlnet import xlnet_config
from official.legacy.xlnet import xlnet_modeling as modeling

flags.DEFINE_integer(
    "num_predict",
    default=None,
    help="Number of tokens to predict in partial prediction.")

# FLAGS for pretrain input preprocessing
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
flags.DEFINE_float("leak_ratio", default=0.1,
                   help="Percent of masked tokens that are leaked.")

flags.DEFINE_enum("sample_strategy", default="token_span",
                  enum_values=["single_token", "whole_word", "token_span",
                               "word_span"],
                  help="Stragey used to sample prediction targets.")
flags.DEFINE_integer("max_num_tokens", default=5,
                     help="Maximum number of tokens to sample in a span."
                     "Effective when token_span strategy is used.")
flags.DEFINE_integer("min_num_tokens", default=1,
                     help="Minimum number of tokens to sample in a span."
                     "Effective when token_span strategy is used.")

flags.DEFINE_integer("max_num_words", default=5,
                     help="Maximum number of whole words to sample in a span."
                     "Effective when word_span strategy is used.")
flags.DEFINE_integer("min_num_words", default=1,
                     help="Minimum number of whole words to sample in a span."
                     "Effective when word_span strategy is used.")
FLAGS = flags.FLAGS


def get_pretrainxlnet_model(model_config, run_config):
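  """Builds a `PretrainingXLNetModel` from the given model and run configs."""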
  return modeling.PretrainingXLNetModel(
      use_proj=True,
      xlnet_config=model_config,
      run_config=run_config,
      name="model")


def main(unused_argv):
  del unused_argv
  num_hosts = 1
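  # Set up the distribution strategy (e.g. TPU or GPU) from command-line flags.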
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      tpu_address=FLAGS.tpu)
  if FLAGS.strategy_type == "tpu":
    num_hosts = strategy.extended.num_hosts
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
    logging.info("***** Number of hosts used : %d", num_hosts)
  online_masking_config = data_utils.OnlineMaskingConfig(
      sample_strategy=FLAGS.sample_strategy,
      max_num_tokens=FLAGS.max_num_tokens,
      min_num_tokens=FLAGS.min_num_tokens,
      max_num_words=FLAGS.max_num_words,
      min_num_words=FLAGS.min_num_words)

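  # Input function that reads the pretraining TFRecords and applies the
  # online masking configuration above.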
  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
      strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
      FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
      num_hosts)

  total_training_steps = FLAGS.train_steps

  steps_per_loop = FLAGS.iterations

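  # Create the optimizer and learning-rate schedule (warmup followed by decay).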
  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=total_training_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

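  # Assemble the XLNet model/run configs and the metadata consumed by the
  # training utilities (per-core batch size, memory length, etc.).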
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                               run_config)

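  # Run the training loop; `training_utils.train` returns the trained
  # pretraining model, whose Transformer-XL backbone is exported below.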
  model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

  # Export transformer-xl model checkpoint to be used in finetuning.
  checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
  saved_path = checkpoint.save(
      os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
               saved_path)


if __name__ == "__main__":
  app.run(main)