# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in tf2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import json
import math

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

# Import BERT model libraries.
from official.bert import bert_models
from official.bert import common_flags
from official.bert import input_pipeline
from official.bert import model_saving_utils
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib

flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
    'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
    'trains the model and evaluates it during training. '
    '`export_only`: takes the latest checkpoint inside '
    'model_dir and exports a `SavedModel`.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
# Model training specific flags.
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()
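
# Example invocation (a sketch; the data and checkpoint paths below are
# hypothetical; `bert_config_file`, `model_dir` and `strategy_type` come from
# common_flags.define_common_bert_flags()):
#
#   python run_classifier.py \
#     --mode=train_and_eval \
#     --strategy_type=mirror \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_meta_data_path=/tmp/bert/input_meta_data \
#     --train_data_path=/tmp/bert/train.tf_record \
#     --eval_data_path=/tmp/bert/eval.tf_record \
#     --model_dir=/tmp/bert20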

FLAGS = flags.FLAGS


def get_loss_fn(num_classes, loss_scale=1.0):
  """Gets the classification loss function."""

  def classification_loss_fn(labels, logits):
    """Classification loss."""
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    loss *= loss_scale
    return loss

  return classification_loss_fn
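
# A minimal sketch of what the returned loss function computes, on a toy batch
# of two examples and three classes (illustrative values only):
#
#   loss_fn = get_loss_fn(num_classes=3)
#   labels = tf.constant([[1], [2]])          # shape (2, 1); squeezed to (2,)
#   logits = tf.constant([[2.0, 1.0, 0.1],
#                         [0.5, 0.5, 2.5]])
#   loss = loss_fn(labels, logits)            # mean softmax cross-entropy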


def run_customized_training(strategy,
                            bert_config,
                            input_meta_data,
                            model_dir,
                            epochs,
                            steps_per_epoch,
                            steps_per_loop,
                            eval_steps,
                            warmup_steps,
                            initial_lr,
                            init_checkpoint,
                            use_remote_tpu=False,
                            custom_callbacks=None,
                            run_eagerly=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  train_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.train_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.train_batch_size)
  eval_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.eval_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.eval_batch_size,
      is_training=False,
      drop_remainder=False)

  def _get_classifier_model():
    classifier_model, core_model = (
        bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                     max_seq_length))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when FLAGS.fp16_implementation == 'graph_rewrite', the model is
      # built in float32 above, which ensures the Keras mixed-precision API and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              classifier_model.optimizer))
    return classifier_model, core_model

  loss_fn = get_loss_fn(num_classes, loss_scale=1.0)

  # Defines the evaluation metric function, which creates metrics under the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      use_remote_tpu=use_remote_tpu,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)


def export_classifier(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    ValueError: If `model_export_path` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  classifier_model = bert_models.classifier_model(
      bert_config, tf.float32, input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])[0]
  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=FLAGS.model_dir)
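
# A sketch of consuming the exported artifact, assuming export_bert_model
# writes a standard SavedModel with a default serving signature (both the
# path and the signature name are assumptions):
#
#   imported = tf.saved_model.load('/tmp/bert_classifier_export')
#   infer = imported.signatures['serving_default']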


def run_bert(strategy, input_meta_data):
  """Run BERT training."""
  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
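  # For example (hypothetical numbers): with train_data_size=25000,
  # train_batch_size=32 and epochs=3, steps_per_epoch = 781 and
  # warmup_steps = int(3 * 25000 * 0.1 / 32) = 234, i.e. the learning rate
  # warms up over roughly 10% of the total training steps.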

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')
  # Runs customized training loop.
  logging.info('Training using customized training loop with TF 2.0 and a '
               'distribution strategy.')
  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
  trained_model = run_customized_training(
      strategy,
      bert_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu,
      run_eagerly=FLAGS.run_eagerly)

  if FLAGS.model_export_path:
    with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
      model_saving_utils.export_bert_model(
          FLAGS.model_export_path, model=trained_model)
  return trained_model


def main(_):
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
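  # input_meta_data is a JSON object; this script reads the following keys
  # (the values here are illustrative):
  #
  #   {
  #     "max_seq_length": 128,
  #     "num_labels": 2,
  #     "train_data_size": 25000,
  #     "eval_data_size": 25000
  #   }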

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)
  run_bert(strategy, input_meta_data)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)