# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import math
import os

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model and evaluates it at the same time. '
    '`export_only`: takes the latest checkpoint inside '
    'model_dir and exports a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS
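
# Example invocation (illustrative only; every path and flag value below is an
# assumption, not prescribed by this module):
#   python run_classifier.py \
#     --mode=train_and_eval \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_meta_data_path=/tmp/bert/meta_data.json \
#     --train_data_path=/tmp/bert/train.tf_record \
#     --eval_data_path=/tmp/bert/eval.tf_record \
#     --model_dir=/tmp/bert20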


def get_loss_fn(num_classes):
  """Gets the classification loss function."""

  def classification_loss_fn(labels, logits):
    """Classification loss."""
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(
        tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn
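
# Example (a hedged sketch, not executed anywhere in this file): with three
# classes,
#   loss_fn = get_loss_fn(num_classes=3)
#   loss = loss_fn(labels=tf.constant([[0], [2]]),
#                  logits=tf.constant([[2.0, 0.5, 0.1], [0.2, 0.3, 2.5]]))
# computes the sparse cross-entropy -log_softmax(logits)[i, labels[i]] per
# example and returns its mean over the batch.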


def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training):
  """Gets a closure to create a dataset."""

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx)
    return dataset

  return _dataset_fn
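
# Example (illustrative; the file pattern and sizes are assumptions): build the
# closure and let the strategy create per-replica datasets, mirroring how
# `eval_input_fn` is consumed in `get_predictions_and_labels` below:
#   eval_fn = get_dataset_fn('/tmp/bert/eval.tf_record', 128, 32,
#                            is_training=False)
#   dataset = strategy.experimental_distribute_datasets_from_function(eval_fn)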


def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
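    # Note: `configure_optimizer` above wraps the optimizer with loss scaling
    # when float16 is in use and can enable the mixed-precision graph rewrite.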
    return classifier_model, core_model

  loss_fn = get_loss_fn(num_classes)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'accuracy', dtype=tf.float32)

  # Start training using Keras compile/fit API.
  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)


def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
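      # Restore only the core encoder (`sub_model`), so a pretrained BERT
      # checkpoint can initialize the encoder while the new classification
      # head keeps its fresh initialization.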
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[metric_fn()],
        experimental_steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
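    # `checkpoint_interval=0` together with the optimizer's iteration counter
    # leaves the saving cadence to `SimpleCheckpoint`, which drives
    # `checkpoint_manager.save()` from the Keras callback hooks.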

    if training_callbacks:
      if custom_callbacks is not None:
        custom_callbacks += [summary_callback, checkpoint_callback]
      else:
        custom_callbacks = [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats


def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that list of labels is returned along with the predictions because the
  order changes on distributing dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions.
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      probabilities = tf.nn.softmax(logits)
      return probabilities, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # outputs: per-replica probabilities for the current batch; flatten each
    # PerReplica value into a tuple of local (per-shard) results.
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps."""
    preds, golds = list(), list()
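    # `tf.experimental.async_scope` lets steps run ahead without blocking on
    # the host each iteration; exhausting the dataset surfaces as an error
    # that `async_clear_error` resets, keeping the results gathered so far.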
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if return_probs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(
      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels


def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: A `BertConfig` object that defines the core BERT layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: If the export path or the model directory is not specified,
      i.e. got an empty string or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf.keras.mixed_precision.experimental.set_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config, input_meta_data['num_labels'])[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
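
# Example (illustrative): the exported `SavedModel` can later be reloaded for
# inference with, e.g., `tf.saved_model.load(model_export_path)`; the loading
# API used downstream is an assumption, not something this module prescribes.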


def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None):
  """Run BERT training."""
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
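  # Linear warmup over the first 10% of all training steps (0.1 * epochs *
  # train_data_size / train_batch_size), following the common BERT
  # fine-tuning recipe.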
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  if not custom_callbacks:
    custom_callbacks = []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model


def custom_main(custom_callbacks=None):
  """Run classification.

  Args:
    custom_callbacks: list of tf.keras.Callbacks passed to training loop.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
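  # The meta data JSON is expected to provide at least 'max_seq_length',
  # 'num_labels', 'train_data_size' and 'eval_data_size', as consumed by
  # `run_bert` and `run_bert_classifier` above.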

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False)

  if FLAGS.mode == 'predict':
    with strategy.scope():
      classifier_model = bert_models.classifier_model(
          bert_config, input_meta_data['num_labels'])[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      assert latest_checkpoint_file
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy, classifier_model, eval_input_fn, return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks)


def main(_):
  custom_main(custom_callbacks=None)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)