"tests/csrc/unittests/test_penalty_kernels.cu" did not exist on "53d2e42cbe70898f9b13969e896a8f555d2947aa"
run_pretraining.py 6.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from official.modeling import model_training_utils
from official.nlp import bert_modeling as modeling
from official.nlp import bert_models
from official.nlp import optimization
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import distribution_utils
from official.utils.misc import tpu_lib

flags.DEFINE_string('input_files', None,
                    'File path to retrieve training data for pre-training.')
# Model training specific flags.
flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum number of masked LM predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS
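
# Example invocation (illustrative only; paths and flag values below are
# placeholders, not recommended settings; the last few flags are defined by
# common_flags.define_common_bert_flags()):
#
#   python run_pretraining.py \
#       --input_files=/path/to/pretrain_examples.tfrecord* \
#       --bert_config_file=/path/to/bert_config.json \
#       --model_dir=/tmp/bert20 \
#       --max_seq_length=128 \
#       --max_predictions_per_seq=20 \
#       --train_batch_size=32 \
#       --num_steps_per_epoch=1000 \
#       --num_train_epochs=3 \
#       --learning_rate=2e-5 \
#       --warmup_steps=10000 \
#       --distribution_strategy=mirrored \
#       --num_gpus=1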


def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size):
  """Returns input dataset from input file string."""
  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_patterns = input_file_pattern.split(',')
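    # `ctx` is a tf.distribute.InputContext when this function is invoked by a
    # distribution strategy; it supplies the per-replica batch size and the
    # sharding information consumed by the input pipeline below.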
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    train_dataset = input_pipeline.create_pretrain_dataset(
        input_patterns,
        seq_length,
        max_predictions_per_seq,
        batch_size,
        is_training=True,
        input_pipeline_context=ctx)
    return train_dataset

  return _dataset_fn


def get_loss_fn(loss_factor=1.0):
  """Returns loss function for BERT pretraining."""

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    return tf.keras.backend.mean(losses) * loss_factor

  return _bert_pretrain_loss_fn


def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when FLAGS.fp16_implementation == 'graph_rewrite', the Keras
      # dtype policy remains 'float32', which ensures that
      # tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # both apply mixed precision.
      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          pretrain_model.optimizer)
    return pretrain_model, core_model

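  # The custom loop handles distribution, checkpointing to `model_dir`, and
  # exporting the core BERT encoder under `sub_model_export_name`.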
  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 /
          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model')

  return trained_model


def run_bert_pretrain(strategy):
  """Runs BERT pre-training."""

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  if not strategy:
    raise ValueError('Distribution strategy is not specified.')

  # Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')

  return run_customized_training(
      strategy,
      bert_config,
      FLAGS.max_seq_length,
      FLAGS.max_predictions_per_seq,
      FLAGS.model_dir,
      FLAGS.num_steps_per_epoch,
      FLAGS.steps_per_loop,
      FLAGS.num_train_epochs,
      FLAGS.learning_rate,
      FLAGS.warmup_steps,
      FLAGS.input_files,
      FLAGS.train_batch_size)


def main(_):
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)


if __name__ == '__main__':
  app.run(main)