run_classifier.py 18.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""BERT classification or regression finetuning runner in TF 2.x."""
16
17
18
19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
import functools
21
22
import json
import math
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
23
import os
24

Hongkun Yu's avatar
Hongkun Yu committed
25
# Import libraries
26
27
28
from absl import app
from absl import flags
from absl import logging
Le Hou's avatar
Le Hou committed
29
import gin
30
import tensorflow as tf
31
from official.modeling import performance
32
from official.nlp import optimization
33
from official.nlp.bert import bert_models
34
from official.nlp.bert import common_flags
35
from official.nlp.bert import configs as bert_configs
36
37
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
38
from official.utils.misc import distribution_utils
39
from official.utils.misc import keras_utils
40
41

flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model and evaluates in the meantime. '
    '`export_only`: will take the latest checkpoint inside '
    'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
# The three integer flags below are used as divisors when run_bert computes
# epochs, steps per epoch, warmup steps and eval steps. Zero or negative values
# are therefore rejected at flag-parsing time instead of failing later with a
# ZeroDivisionError (or silently producing negative step counts).
flags.DEFINE_integer(
    'num_eval_per_epoch', 1,
    'Number of evaluations per epoch. The purpose of this flag is to provide '
    'more granular evaluation scores and checkpoints. For example, if original '
    'data has N samples and num_eval_per_epoch is n, then each epoch will be '
    'evaluated every N/n samples.',
    lower_bound=1)
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.',
                     lower_bound=1)
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.',
                     lower_bound=1)

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

# Maps the `label_type` string from the input meta data file ('int' or
# 'float') to the dtype passed as `label_type` to the input pipeline.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

73

74
def get_loss_fn(num_classes):
  """Builds a softmax cross-entropy loss for `num_classes`-way classification.

  Args:
    num_classes: number of target classes; used as the one-hot depth.

  Returns:
    A callable `(labels, logits) -> scalar tensor` that returns the mean
    per-example cross-entropy of `logits` against the class ids in `labels`
    (labels are squeezed and cast to int32 first).
  """

  def classification_loss_fn(labels, logits):
    """Mean softmax cross-entropy between `labels` and `logits`."""
    class_ids = tf.cast(tf.squeeze(labels), dtype=tf.int32)
    targets = tf.one_hot(class_ids, depth=num_classes, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # `targets` is already float32, so no further cast is needed here.
    per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn


Tianqi Liu's avatar
Tianqi Liu committed
90
91
92
93
def get_dataset_fn(input_file_pattern,
                   max_seq_length,
                   global_batch_size,
                   is_training,
                   label_type=tf.int64,
                   include_sample_weights=False):
  """Gets a closure that builds the classifier `tf.data.Dataset`.

  Args:
    input_file_pattern: glob pattern for the input files; expanded with
      `tf.io.gfile.glob`.
    max_seq_length: maximum sequence length forwarded to the input pipeline.
    global_batch_size: batch size across all replicas.
    is_training: whether the dataset is built for training.
    label_type: label dtype forwarded to the input pipeline.
    include_sample_weights: whether the input pipeline should also emit
      sample weights.

  Returns:
    A function `_dataset_fn(ctx=None)` returning the dataset. When an input
    context `ctx` is supplied (e.g. by
    `strategy.experimental_distribute_datasets_from_function`), the global
    batch size is converted to the per-replica batch size; without a context
    the global batch size is used as-is.
  """

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for BERT classifier finetuning."""
    if ctx:
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    else:
      batch_size = global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights)
    return dataset

  return _dataset_fn


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
115
116
117
118
119
120
121
122
123
124
125
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None,
                        custom_metrics=None):
  """Runs BERT classifier training using the Keras compile/fit API.

  Builds the classifier model and optimizer, and picks the loss and default
  metric. A single label (`num_labels` == 1) means regression: mean squared
  error is used for both loss and metric. Otherwise the loss is cross-entropy
  and the metric is sparse categorical accuracy. Training is then delegated to
  `run_keras_compile_fit`.

  Args:
    strategy: distribution strategy used for training.
    bert_config: BERT configuration used to build the core model.
    input_meta_data: dict with at least `max_seq_length`; `num_labels`
      defaults to 1, which selects regression.
    model_dir: directory for checkpoints and summaries.
    epochs: number of training epochs.
    steps_per_epoch: training steps per epoch.
    steps_per_loop: steps executed per `tf.function` call.
    eval_steps: number of evaluation steps.
    warmup_steps: learning-rate warmup steps for the optimizer.
    initial_lr: initial learning rate.
    init_checkpoint: optional checkpoint restored into the core model.
    train_input_fn: callable returning the training dataset.
    eval_input_fn: callable returning the evaluation dataset, or None.
    training_callbacks: whether to add summary/checkpoint callbacks.
    custom_callbacks: optional list of extra Keras callbacks.
    custom_metrics: optional metric factory (or list of factories) that
      replaces the default metric.

  Returns:
    The `(model, stats)` tuple returned by `run_keras_compile_fit`.
  """
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

  # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
  # from the dataset) to compute weighted loss, as used for the regression
  # tasks. The classification tasks, using the custom get_loss_fn don't accept
  # sample weights though.
  loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
             else get_loss_fn(num_classes))

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  if custom_metrics:
    metric_fn = custom_metrics
  elif is_regression:
    metric_fn = functools.partial(
        tf.keras.metrics.MeanSquaredError,
        'mean_squared_error',
        dtype=tf.float32)
  else:
    metric_fn = functools.partial(
        tf.keras.metrics.SparseCategoricalAccuracy,
        'accuracy',
        dtype=tf.float32)

  # Start training using Keras compile/fit API.
  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)
195
196


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
197
198
199
200
201
202
203
204
205
206
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API.

  Args:
    model_dir: directory for checkpoints; TensorBoard summaries are written
      under `<model_dir>/summaries`.
    strategy: distribution strategy; datasets, model and metrics are created
      inside its scope.
    model_fn: zero-argument callable returning `(model, sub_model)`. `model`
      must already have an optimizer attached; `sub_model` is what
      `init_checkpoint` is restored into.
    train_input_fn: zero-argument callable returning the training dataset.
    eval_input_fn: zero-argument callable returning the evaluation dataset,
      or None to train without validation.
    loss_fn: loss passed to `model.compile`.
    metric_fn: metric factory, or a list/tuple of factories; each is called
      with no arguments to create a metric.
    init_checkpoint: optional checkpoint path restored into `sub_model`.
    epochs: number of epochs passed to `fit`.
    steps_per_epoch: training steps per epoch.
    steps_per_loop: passed as `experimental_steps_per_execution` to `compile`.
    eval_steps: validation steps per evaluation.
    training_callbacks: whether to add the TensorBoard summary and checkpoint
      callbacks.
    custom_callbacks: optional list of extra Keras callbacks. It is not
      modified by this function.

  Returns:
    A `(model, stats)` tuple. `stats` always contains `total_training_steps`,
    plus the final `train_loss` and the final validation accuracy (as
    `eval_metrics`) when those appear in the Keras history.
  """
  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        experimental_steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    # Build a new list rather than extending `custom_callbacks` in place, so the
    # caller's list does not accumulate summary/checkpoint callbacks when it is
    # reused across calls.
    callbacks = (list(custom_callbacks)
                 if custom_callbacks is not None else None)
    if training_callbacks:
      callbacks = (callbacks or []) + [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
261
262


Hongkun Yu's avatar
Hongkun Yu committed
263
264
265
def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               is_regression=False,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that list of labels is returned along with the predictions because the
  order changes on distributing dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data; passed to
      `strategy.experimental_distribute_datasets_from_function`.
    is_regression: Whether it is a regression task. Regression outputs are
      used as-is instead of being passed through softmax.
    return_probs: Whether to return probabilities of classes. For regression
      the raw model outputs are returned regardless of this flag, since an
      argmax over regression outputs is meaningless.

  Returns:
    predictions: List of predictions. These are per-example lists of class
      probabilities if `return_probs`, per-example lists of raw outputs for
      regression, and otherwise argmax class ids (as Python ints).
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      if not is_regression:
        probabilities = tf.nn.softmax(logits)
        return probabilities, labels
      else:
        return logits, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # Unpack the per-replica values of the current batch into tuples holding
    # one tensor per local replica (probabilities, or raw outputs for
    # regression).
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  # Regression outputs are not class scores, so they are always kept raw.
  keep_raw_outputs = return_probs or is_regression

  def _run_evaluation(test_iterator):
    """Runs evaluation steps until the iterator is exhausted."""
    preds, golds = list(), list()
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if keep_raw_outputs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              preds.extend(
                  tf.math.argmax(cur_probs, axis=1).numpy().tolist())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(
      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels


Hongkun Yu's avatar
Hongkun Yu committed
330
331
def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model;
      `num_labels` (default 1) sets the number of classifier outputs.
    bert_config: BERT configuration object defining the core BERT layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored; the exported weights are loaded from here.

  Raises:
    ValueError: if `model_export_path` or `model_dir` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf.keras.mixed_precision.experimental.set_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=False)[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
359
360


Hongkun Yu's avatar
Hongkun Yu committed
361
362
def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training.

  Args:
    strategy: distribution strategy; must not be None.
    input_meta_data: dict with at least `train_data_size`, `eval_data_size`
      and `max_seq_length`.
    model_config: BERT configuration used to build the classifier.
    train_input_fn: callable returning the training dataset; required.
    eval_input_fn: callable returning the evaluation dataset, or None.
    init_checkpoint: optional initial checkpoint; falls back to
      `FLAGS.init_checkpoint`.
    custom_callbacks: optional list of extra Keras callbacks. It is copied,
      never modified.
    custom_metrics: optional metric factory (or list of factories) that
      replaces the default metric.

  Returns:
    The trained Keras classifier model.

  Raises:
    ValueError: if `strategy` or `train_input_fn` is not specified.
  """
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  # Each Keras "epoch" covers 1/num_eval_per_epoch of the training data, so
  # evaluation and checkpointing happen num_eval_per_epoch times per pass over
  # the data.
  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up the learning rate over the first 10% of all training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')
  if train_input_fn is None:
    # Fail early with a clear message instead of calling `None()` deep inside
    # the training loop.
    raise ValueError('Training input function has not been specified.')

  # Copy so that appending TimeHistory below never mutates the caller's list.
  custom_callbacks = list(custom_callbacks) if custom_callbacks else []
  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model

417

418
def custom_main(custom_callbacks=None, custom_metrics=None):
  """Run classification or regression.

  Dispatches on `--mode`: `export_only` exports a SavedModel, `predict` writes
  per-example outputs to `<model_dir>/test_results.tsv`, and `train_and_eval`
  trains and evaluates the classifier.

  Args:
    custom_callbacks: list of tf.keras.Callbacks passed to training loop.
    custom_metrics: list of metrics passed to the training loop.

  Raises:
    ValueError: if no checkpoint can be found in `predict` mode, or if
      `--mode` is not a supported value.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  include_sample_weights = input_meta_data.get('has_sample_weights', False)

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type,
      include_sample_weights=include_sample_weights)

  if FLAGS.mode == 'predict':
    num_labels = input_meta_data.get('num_labels', 1)
    with strategy.scope():
      # Pass the optional TF-Hub module like training and export do, so a
      # checkpoint trained with --hub_module_url can be restored here.
      classifier_model = bert_models.classifier_model(
          bert_config,
          num_labels,
          hub_module_url=FLAGS.hub_module_url,
          hub_module_trainable=False)[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      # Explicit check instead of `assert`, which is stripped under `python -O`.
      if not latest_checkpoint_file:
        raise ValueError(
            'No checkpoint found for prediction: set --predict_checkpoint_path '
            'or make sure --model_dir (%s) contains a checkpoint.' %
            FLAGS.model_dir)
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy,
          classifier_model,
          eval_input_fn,
          is_regression=(num_labels == 1),
          return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type,
      include_sample_weights=include_sample_weights)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)
501
502
503


def main(_):
  """absl entry point: runs `custom_main` without extra callbacks or metrics.

  The positional argv supplied by `app.run` is ignored.
  """
  custom_main(None, None)
505
506
507
508
509


if __name__ == '__main__':
  # These flags have no usable default, so absl refuses to start without them.
  for required_flag in ('bert_config_file', 'input_meta_data_path',
                        'model_dir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)