"test/algo/compression/v2/test_pruning_wrapper.py" did not exist on "1a3c019afdb64800063b00037af53c7c5be97b37"
run_classifier.py 16.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""BERT classification or regression finetuning runner in TF 2.x."""
16
17
18
19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
import functools
21
22
import json
import math
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
23
import os
24
25
26
27

from absl import app
from absl import flags
from absl import logging
Le Hou's avatar
Le Hou committed
28
import gin
29
import tensorflow as tf
30
from official.modeling import performance
31
from official.nlp import optimization
32
from official.nlp.bert import bert_models
33
from official.nlp.bert import common_flags
34
from official.nlp.bert import configs as bert_configs
35
36
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
37
from official.utils.misc import distribution_utils
38
from official.utils.misc import keras_utils
39
40

# Command-line flags specific to this classifier runner. Other flags read via
# FLAGS below (e.g. `bert_config_file`, `model_dir`) are not defined here;
# presumably `common_flags.define_common_bert_flags()` registers them.
flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model and evaluates in the meantime. '
    '`export_only`: will take the latest checkpoint inside '
    'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

# Maps the optional `label_type` entry of the input meta data JSON to the
# dtype passed as `label_type` to the input pipeline.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}


def get_loss_fn(num_classes):
  """Gets the classification loss function.

  Args:
    num_classes: Number of target classes; used as the one-hot depth.

  Returns:
    A callable `(labels, logits) -> scalar` that computes the batch mean of the
    softmax cross-entropy between integer class `labels` and `logits`.
  """

  def classification_loss_fn(labels, logits):
    """Mean softmax cross-entropy over the batch."""
    # Drops size-1 dimensions, e.g. labels of shape [batch, 1] -> [batch].
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # `tf.one_hot` already produces float32 here, so no extra cast is needed.
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn


def get_regression_loss_fn():
  """Gets the regression loss function.

  Returns:
    A callable `(labels, logits) -> scalar` that computes the mean squared
    error between `labels` and single-output `logits`.
  """

  def regression_loss_fn(labels, logits):
    """Mean squared error over the batch."""
    labels = tf.cast(labels, dtype=tf.float32)
    # Flatten both sides so labels of shape [batch] and logits of shape
    # [batch, 1] are compared element-wise. Without this, broadcasting yields
    # a [batch, batch] matrix and a wrong mean. Assumes exactly one regression
    # output per example (num_labels == 1). The cast keeps dtypes aligned for
    # `squared_difference`.
    labels = tf.reshape(labels, [-1])
    logits = tf.reshape(tf.cast(logits, dtype=tf.float32), [-1])
    per_example_loss = tf.math.squared_difference(labels, logits)
    return tf.reduce_mean(per_example_loss)

  return regression_loss_fn


def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training, label_type=tf.int64):
  """Gets a closure that creates the classifier `tf.data.Dataset`.

  Args:
    input_file_pattern: Glob pattern of input files, expanded with
      `tf.io.gfile.glob`.
    max_seq_length: Maximum sequence length, forwarded to the input pipeline.
    global_batch_size: Batch size summed over all replicas.
    is_training: Whether the dataset is built for training.
    label_type: dtype of the labels, forwarded to the input pipeline.

  Returns:
    A function `_dataset_fn(ctx=None)`. It can be called with an input
    context (e.g. via
    `strategy.experimental_distribute_datasets_from_function`) or with no
    arguments.
  """

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for BERT classifier finetuning."""
    # With an input context, use the per-replica share of the global batch;
    # otherwise use the global batch size as-is.
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type)
    return dataset

  return _dataset_fn


def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None):
  """Runs BERT classifier training and evaluation via Keras compile/fit.

  Args:
    strategy: Distribution strategy whose scope the model is built in.
    bert_config: BERT configuration used to build the classifier model.
    input_meta_data: Dict read from the meta data file. `max_seq_length` is
      required; `num_labels` defaults to 1, which selects regression.
    model_dir: Directory for checkpoints and TensorBoard summaries.
    epochs: Number of training epochs.
    steps_per_epoch: Number of training steps per epoch.
    steps_per_loop: Steps per execution passed to `Model.compile`.
    eval_steps: Number of validation steps per evaluation.
    warmup_steps: Number of learning-rate warmup steps.
    initial_lr: Learning rate passed to `optimization.create_optimizer`.
    init_checkpoint: Optional checkpoint restored into the core BERT model.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Callable returning the evaluation dataset, or None.
    training_callbacks: Whether to add summary and checkpoint callbacks.
    custom_callbacks: Optional list of additional Keras callbacks.

  Returns:
    The `(model, stats)` tuple returned by `run_keras_compile_fit`.
  """
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  # A single output unit means the task is treated as regression.
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Builds classifier/core models and attaches the configured optimizer."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

  loss_fn = (get_regression_loss_fn() if is_regression
             else get_loss_fn(num_classes))

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  if is_regression:
    metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
                                  'mean_squared_error', dtype=tf.float32)
  else:
    metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
                                  'accuracy', dtype=tf.float32)

  # Start training using Keras compile/fit API.
  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)


def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API.

  Args:
    model_dir: Directory for checkpoints; summaries go to
      `<model_dir>/summaries`.
    strategy: Distribution strategy; datasets, model and metrics are created
      inside its scope.
    model_fn: Callable returning `(model, sub_model)`. `model.optimizer` must
      already be set.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Callable returning the validation dataset, or None.
    loss_fn: Loss passed to `Model.compile`.
    metric_fn: Callable returning a single metric passed to `Model.compile`.
    init_checkpoint: Optional checkpoint restored into `sub_model`.
    epochs: Number of training epochs.
    steps_per_epoch: Training steps per epoch.
    steps_per_loop: Steps per execution passed to `Model.compile`.
    eval_steps: Validation steps per evaluation.
    training_callbacks: Whether to add TensorBoard and checkpoint callbacks.
    custom_callbacks: Optional list of extra callbacks. The list is not
      modified.

  Returns:
    A tuple `(bert_model, stats)`. `stats` always contains
    `total_training_steps`. `train_loss` and `eval_metrics` are added when the
    corresponding entries exist in the Keras training history.
  """

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      # Only `sub_model` is restored from the initial checkpoint.
      init_ckpt = tf.train.Checkpoint(model=sub_model)
      init_ckpt.restore(init_checkpoint).assert_existing_objects_matched()

    # Keep a handle on the metric so its name can be used to find the matching
    # validation entry in the training history.
    metric = metric_fn()
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[metric],
        experimental_steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    # Build a new list rather than extending the caller's in place, so that
    # repeated calls with the same list do not accumulate callbacks.
    callbacks = list(custom_callbacks) if custom_callbacks is not None else None
    if training_callbacks:
      callbacks = (callbacks or []) + [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    # Keras records validation metrics as 'val_' + metric name, e.g.
    # 'val_accuracy' or 'val_mean_squared_error'. Fall back to 'accuracy' if
    # the metric object exposes no name.
    eval_metric_key = 'val_%s' % getattr(metric, 'name', 'accuracy')
    if eval_metric_key in history.history:
      stats['eval_metrics'] = history.history[eval_metric_key][-1]
    return bert_model, stats


def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that list of labels is returned along with the predictions because the
  order changes on distributing dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data, passed to
      `strategy.experimental_distribute_datasets_from_function`.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions. Each entry is a list of per-class
      probabilities if `return_probs` is True, otherwise the argmax class id
      as a Python int.
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      # NOTE(review): with a single output (regression, num_labels == 1) this
      # softmax is always 1.0; this path assumes a classification model.
      probabilities = tf.nn.softmax(logits)
      return probabilities, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # outputs: current batch probabilities as a tuple of per-replica shards.
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps until the iterator is exhausted."""
    preds, golds = list(), list()
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if return_probs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              # `.tolist()` yields Python ints, consistent with the other
              # branches, instead of numpy scalars.
              preds.extend(
                  tf.math.argmax(cur_probs, axis=1).numpy().tolist())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      # End of data: clear the pending async error status raised by the
      # exhausted iterator.
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(
      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels


def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  As a side effect, sets the global Keras mixed-precision policy to float32.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
      `num_labels` defaults to 1 when absent.
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: if `model_export_path` or `model_dir` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf.keras.mixed_precision.experimental.set_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config, input_meta_data.get('num_labels', 1))[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)


def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None):
  """Run BERT training.

  Args:
    strategy: Distribution strategy; must be set.
    input_meta_data: Dict read from the meta data file. `train_data_size` and
      `eval_data_size` are read here; other keys are read by
      `run_bert_classifier`.
    model_config: BERT configuration passed to `run_bert_classifier`.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Callable returning the evaluation dataset.
    init_checkpoint: Optional initial checkpoint; falls back to
      `FLAGS.init_checkpoint` when not given.
    custom_callbacks: Optional list of Keras callbacks. It is copied, not
      modified.

  Returns:
    The trained Keras classifier model.

  Raises:
    ValueError: if `strategy` is not specified.
  """
  # Fail fast, before changing any global session/precision configuration.
  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up over roughly the first 10% of all training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  # Copy so that appending below never mutates the caller's list.
  custom_callbacks = list(custom_callbacks) if custom_callbacks else []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model


def custom_main(custom_callbacks=None):
  """Run classification or regression.

  Reads the input meta data and BERT config named by the flags, then does one
  of the following:
    * `export_only`: exports a SavedModel.
    * `predict`: writes per-class probabilities for the evaluation data to
      `<model_dir>/test_results.tsv`.
    * `train_and_eval`: trains and evaluates.

  Args:
    custom_callbacks: list of tf.keras.Callbacks passed to training loop.

  Raises:
    ValueError: if the mode is unsupported, or if no checkpoint can be found
      in `predict` mode.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  # A missing `num_labels` means a single-output model. This is the same
  # default used when building the model for training and export.
  num_labels = input_meta_data.get('num_labels', 1)

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type)

  if FLAGS.mode == 'predict':
    with strategy.scope():
      classifier_model = bert_models.classifier_model(
          bert_config, num_labels)[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      # Raise instead of `assert` so the check also runs under `python -O`.
      if not latest_checkpoint_file:
        raise ValueError(
            'No checkpoint found: set predict_checkpoint_path or use a '
            'model_dir that contains a checkpoint (model_dir=%s).' %
            FLAGS.model_dir)
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy, classifier_model, eval_input_fn, return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      # One tab-separated line of class probabilities per example.
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks)


def main(_):
  """absl entry point: runs `custom_main` with no extra callbacks."""
  del _  # argv passed by absl; unused.
  custom_main()


if __name__ == '__main__':
  # absl fails fast at startup if any of these flags is missing.
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)