run_pretraining.py 8.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Hongkun Yu's avatar
Hongkun Yu committed
15
"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
16
17
18
19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Hongkun Yu's avatar
Hongkun Yu committed
20
# Import libraries
21
22
23
from absl import app
from absl import flags
from absl import logging
Hongkun Yu's avatar
Hongkun Yu committed
24
import gin
25
import tensorflow as tf
26
from official.modeling import performance
27
from official.nlp import optimization
28
from official.nlp.bert import bert_models
29
from official.nlp.bert import common_flags
30
from official.nlp.bert import configs
31
from official.nlp.bert import input_pipeline
32
from official.nlp.bert import model_training_utils
33
from official.utils.misc import distribution_utils
34

flags.DEFINE_string('input_files', None,
                    'File path to retrieve training data for pre-training.')
# Model training specific flags.
flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum predictions per sequence_output.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
# A summary interval is a step count, not a boolean: declaring it with
# DEFINE_bool made any real interval (e.g. `--train_summary_interval=100`)
# fail flag parsing. The default of 0 is unchanged.
flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
                     'summaries. If the value is a negative number, '
                     'then training summaries are not enabled.')

# Flags shared by the BERT scripts. Names read below but not defined in this
# file (e.g. `bert_config_file`, `model_dir`, `learning_rate`) presumably come
# from here — confirm in common_flags.
common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS


def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size,
                            use_next_sentence_label=True):
  """Returns a function that builds the BERT pretraining input dataset.

  Args:
    input_file_pattern: One or more input file patterns joined by commas.
    seq_length: Sequence length forwarded to the input pipeline.
    max_predictions_per_seq: Maximum masked-LM predictions per sequence,
      forwarded to the input pipeline.
    global_batch_size: Batch size summed over all replicas.
    use_next_sentence_label: Forwarded to
      `input_pipeline.create_pretrain_dataset`.

  Returns:
    A function `_dataset_fn(ctx=None)` that returns the training
    `tf.data.Dataset`. When `ctx` is given, the per-replica batch size is
    derived from it; otherwise `global_batch_size` is used as-is.
  """

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining.

    Args:
      ctx: Optional input context. When present, it must provide
        `get_per_replica_batch_size` (presumably a
        `tf.distribute.InputContext`).
    """
    input_patterns = input_file_pattern.split(',')
    # `ctx` defaults to None, so it must be checked before use. Without a
    # context, the whole global batch goes to a single pipeline.
    if ctx is None:
      batch_size = global_batch_size
    else:
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # NOTE(review): assumes create_pretrain_dataset treats
    # input_pipeline_context=None as "no sharding" — confirm.
    return input_pipeline.create_pretrain_dataset(
        input_patterns,
        seq_length,
        max_predictions_per_seq,
        batch_size,
        is_training=True,
        input_pipeline_context=ctx,
        use_next_sentence_label=use_next_sentence_label)

  return _dataset_fn


def get_loss_fn():
  """Creates the loss callable consumed by the custom training loop.

  Returns:
    A function with signature `(labels, losses, **kwargs)`. It ignores
    `labels` and any extra keyword arguments and returns
    `tf.reduce_mean(losses)`.
  """

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    # NOTE(review): `losses` is presumably the pretraining model's own loss
    # output, which is why labels are not needed here — confirm against the
    # training loop.
    mean_loss = tf.reduce_mean(losses)
    return mean_loss

  return _bert_pretrain_loss_fn


def run_customized_training(strategy,
                            bert_config,
                            init_checkpoint,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            end_lr,
                            optimizer_type,
                            input_files,
                            train_batch_size,
                            use_next_sentence_label=True,
                            train_summary_interval=0,
                            custom_callbacks=None,
                            explicit_allreduce=False,
                            pre_allreduce_callbacks=None,
                            post_allreduce_callbacks=None):
  """Trains the BERT pretraining model with the low-level custom loop.

  Args:
    strategy: Distribution strategy forwarded to the training loop.
    bert_config: BERT config used to build the pretraining model.
    init_checkpoint: Checkpoint path forwarded to the training loop.
    max_seq_length: Sequence length used for both the input pipeline and
      the model.
    max_predictions_per_seq: Maximum masked-LM predictions per sequence,
      used for both the input pipeline and the model.
    model_dir: Output directory forwarded to the training loop.
    steps_per_epoch: Steps per epoch, forwarded to the training loop.
    steps_per_loop: Steps per inner loop, forwarded to the training loop.
    epochs: Number of epochs. `steps_per_epoch * epochs` is also passed to
      `optimization.create_optimizer` (presumably as the total step count).
    initial_lr: Initial learning rate for `optimization.create_optimizer`.
    warmup_steps: Warmup steps for `optimization.create_optimizer`.
    end_lr: End learning rate for `optimization.create_optimizer`.
    optimizer_type: Optimizer name for `optimization.create_optimizer`.
    input_files: Comma-separated input file patterns.
    train_batch_size: Global training batch size.
    use_next_sentence_label: Whether the model and the input pipeline use
      the next-sentence label.
    train_summary_interval: Forwarded to the training loop.
    custom_callbacks: Forwarded to the training loop.
    explicit_allreduce: Forwarded to the training loop.
    pre_allreduce_callbacks: Forwarded to the training loop.
    post_allreduce_callbacks: Forwarded to the training loop.

  Returns:
    The value returned by `model_training_utils.run_customized_training_loop`
    (presumably the trained model).
  """
  dataset_fn = get_pretrain_dataset_fn(
      input_files,
      max_seq_length,
      max_predictions_per_seq,
      train_batch_size,
      use_next_sentence_label)

  def _get_pretrain_model():
    """Builds the pretraining model and attaches its configured optimizer."""
    model, bert_core = bert_models.pretrain_model(
        bert_config,
        max_seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
    total_train_steps = steps_per_epoch * epochs
    base_optimizer = optimization.create_optimizer(
        initial_lr, total_train_steps, warmup_steps, end_lr, optimizer_type)
    # The float16 and graph-rewrite settings come from the global flags.
    model.optimizer = performance.configure_optimizer(
        base_optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return model, bert_core

  # NOTE: scale_loss is read from the global FLAGS rather than a parameter.
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      init_checkpoint=init_checkpoint,
      train_input_fn=dataset_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      explicit_allreduce=explicit_allreduce,
      pre_allreduce_callbacks=pre_allreduce_callbacks,
      post_allreduce_callbacks=post_allreduce_callbacks,
      train_summary_interval=train_summary_interval,
      custom_callbacks=custom_callbacks)

def run_bert_pretrain(strategy, custom_callbacks=None):
  """Runs BERT pre-training using settings taken from command-line flags.

  Args:
    strategy: Distribution strategy to train under; must be truthy.
    custom_callbacks: Optional callbacks forwarded to the training loop.

  Returns:
    The value returned by `run_customized_training`.

  Raises:
    ValueError: If `strategy` is not specified.
  """
  # Check the cheap precondition before touching the BERT config file.
  if not strategy:
    raise ValueError('Distribution strategy is not specified.')

  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Runs customized training loop. The literals are concatenated, so the
  # trailing space is required to avoid logging "distributedstrategy".
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')

  performance.set_mixed_precision_policy(common_flags.dtype())

  # With explicit_allreduce=True, apply_gradients() no longer all-reduces
  # gradients implicitly. Gradients are all-reduced manually, and the reduced
  # grads_and_vars are then passed to apply_gradients(). clip_by_global_norm
  # runs before the all-reduce, to stay consistent with the original TF1
  # model.
  return run_customized_training(
      strategy,
      bert_config,
      FLAGS.init_checkpoint,  # Used to initialize only the BERT submodel.
      FLAGS.max_seq_length,
      FLAGS.max_predictions_per_seq,
      FLAGS.model_dir,
      FLAGS.num_steps_per_epoch,
      FLAGS.steps_per_loop,
      FLAGS.num_train_epochs,
      FLAGS.learning_rate,
      FLAGS.warmup_steps,
      FLAGS.end_lr,
      FLAGS.optimizer_type,
      FLAGS.input_files,
      FLAGS.train_batch_size,
      FLAGS.use_next_sentence_label,
      FLAGS.train_summary_interval,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=FLAGS.explicit_allreduce,
      pre_allreduce_callbacks=[
          model_training_utils.clip_by_global_norm_callback
      ])


def main(_):
  """Script entry point, run through `app.run` below.

  Applies gin configuration, fills in a default model directory, builds the
  distribution strategy from flags and launches BERT pretraining.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  if not FLAGS.model_dir:
    # Fall back to a fixed scratch directory when none is supplied.
    FLAGS.model_dir = '/tmp/bert20/'

  # Configures cluster spec for multi-worker distribution strategy; this is
  # done only when GPUs are requested.
  if FLAGS.num_gpus > 0:
    distribution_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)


if __name__ == '__main__':
  app.run(main)