model_training_utils.py 16.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""A light weight utilities to train NLP models."""
16
17
18
19
20

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

21
import json
22
23
24
import os

from absl import logging
Hongkun Yu's avatar
Hongkun Yu committed
25
import tensorflow as tf
26
from official.utils.misc import distribution_utils
27

28
29
# File name of the JSON training summary; `write_txt_summary` joins it onto
# the summary directory it is given.
_SUMMARY_TXT = 'training_summary.txt'
# Smallest `steps_per_loop` for which the training loop creates a train
# summary writer; below this threshold no train summaries are written.
_MIN_SUMMARY_STEPS = 10
30

31
32
33
34
35
36
37
38
39
40

def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves `checkpoint` under `model_dir` and logs where it was written.

  Args:
    checkpoint: Object exposing `save(path)`; the value it returns is logged.
    model_dir: Directory the checkpoint prefix is joined onto.
    checkpoint_prefix: File-name prefix passed (joined with `model_dir`) to
      `checkpoint.save`.
  """
  written_to = checkpoint.save(os.path.join(model_dir, checkpoint_prefix))
  logging.info('Saving model as TF checkpoint: %s', written_to)


41
42
43
44
45
def _get_input_iterator(input_fn, strategy):
  """Returns distributed dataset iterator.

  Args:
    input_fn: Callable that returns a dataset. It is handed to
      `strategy.experimental_distribute_datasets_from_function` rather than
      being called here.
    strategy: Distribution strategy used to distribute the dataset.

  Returns:
    An iterator over the distributed dataset built from `input_fn`.

  Raises:
    ValueError: If `input_fn` is not callable.
  """
  # When training with TPU pods, datasets needs to be cloned across
  # workers. Since Dataset instance cannot be cloned in eager mode, we instead
  # pass callable that returns a dataset.
  if not callable(input_fn):
    raise ValueError('`input_fn` should be a closure that returns a dataset.')
  iterator = iter(
      strategy.experimental_distribute_datasets_from_function(input_fn))
  return iterator


53
54
55
56
57
def _float_metric_value(metric):
  """Reads the current result of a float-valued Keras metric as a float.

  Args:
    metric: Object whose `result()` returns a value exposing `numpy()`.

  Returns:
    `metric.result().numpy()` cast with `astype(float)`.
  """
  current_result = metric.result()
  as_numpy = current_result.numpy()
  return as_numpy.astype(float)


58
def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Calculates steps to run on device.

  Args:
    current_step: Number of training steps already completed.
    steps_per_epoch: Number of steps in one epoch.
    steps_per_loop: Requested number of steps per loop; must be positive.

  Returns:
    `steps_per_loop`, unless `current_step` is part-way through an epoch, in
    which case the result is capped at the number of steps left in that epoch
    so a loop never runs past an epoch boundary.

  Raises:
    ValueError: If `steps_per_loop` is not positive.
  """
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be positive integer.')
  if steps_per_loop == 1:
    # A single step can never cross an epoch boundary.
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  else:
    return steps_per_loop


71
def write_txt_summary(training_summary, summary_dir):
  """Writes a summary text file to record stats.

  Args:
    training_summary: JSON-serializable object; it is logged and then written
      as indented JSON.
    summary_dir: Directory in which the `_SUMMARY_TXT` file is written.
  """
  summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
  with tf.io.gfile.GFile(summary_path, 'wb') as f:
    logging.info('Training Summary: \n%s', str(training_summary))
    f.write(json.dumps(training_summary, indent=4))


79
80
81
82
83
84
85
86
87
88
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model). Caller of this
        function should add optimizer to the `model` via calling
        `model.compile()` API or manually setting `model.optimizer` attribute.
        Second element of the returned tuple(sub_model) is an optional sub model
        to be used for initial checkpoint -- if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, model checkpoint will be saved and evaluation will be conducted
        if evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. In order to reduce
        communication in eager context, training logs are printed every
        steps_per_loop. If it is greater than `steps_per_epoch`, an error is
        logged and `steps_per_epoch` is used instead.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns evaluation dataset. If none,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
        is not none.
      metric_fn: A metrics function that returns a Keras Metric object to record
        evaluation result using evaluation dataset or with training dataset
        after every epoch.
      init_checkpoint: Optional checkpoint to load to `sub_model` returned by
        `model_fn`.
      custom_callbacks: A list of Keras Callbacks objects to run during
        training. More specifically, the `on_batch_begin()` and
        `on_batch_end()` methods are invoked once around each host loop of
        training steps.
      run_eagerly: Whether to run model training in pure eager execution. This
        should be disabled for TPUStrategy.
      sub_model_export_name: If not None, will export `sub_model` returned by
        `model_fn` into checkpoint files. The name of intermediate checkpoint
        file is {sub_model_export_name}_step_{step}.ckpt and the last
        checkpoint's name is {sub_model_export_name}.ckpt;
        if None, `sub_model` will not be exported as checkpoint.

  Returns:
      Trained model.

  Raises:
      ValueError: (1) When model returned by `model_fn` does not have optimizer
        attribute or when required parameters are set to none. (2) eval args are
        not specified correctly. (3) metric_fn must be a callable if specified.
        (4) sub_model_export_name is specified, but `sub_model` returned
        by `model_fn` is None. (5) `run_eagerly` is set together with a
        TPUStrategy. (6) Positional arguments are used.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if any(arg is None for arg in required_arguments):
    # The message lists exactly the arguments checked above.
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')
  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None or metric_fn is None):
    raise ValueError(
        '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
        'is not none.')
  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  total_training_steps = steps_per_epoch * epochs

  # To reduce unnecessary send/receive input pipeline operation, we place input
  # pipeline ops in worker task.
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer
    # A loss-scale optimizer signals mixed-precision training; gradients then
    # have to be computed on the scaled loss and unscaled before being applied.
    use_float16 = isinstance(
        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)

    if init_checkpoint:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    train_loss_metric = tf.keras.metrics.Mean(
        'training_loss', dtype=tf.float32)
    eval_metrics = [metric_fn()] if metric_fn else []
    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation.
    train_metrics = [
        metric.__class__.from_config(metric.get_config())
        for metric in eval_metrics
    ]

    # Create summary writers
    summary_dir = os.path.join(model_dir, 'summaries')
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, 'eval'))
    if steps_per_loop >= _MIN_SUMMARY_STEPS:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = None

    # Collects training variables.
    training_vars = model.trainable_variables

    def _replicated_step(inputs):
      """Replicated training step."""

      inputs, labels = inputs
      with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
        if use_float16:
          scaled_loss = optimizer.get_scaled_loss(loss)

      if use_float16:
        scaled_grads = tape.gradient(scaled_loss, training_vars)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
      else:
        grads = tape.gradient(loss, training_vars)
      optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)
      for metric in train_metrics:
        metric.update_state(labels, model_outputs)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: an tf.int32 integer tensor to specify number of steps to run
          inside host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be an Tensor. Python object may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
        for metric in eval_metrics:
          metric.update_state(labels, model_outputs)

      strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))

    # Unless pure eager execution was requested, wrap the per-step functions
    # in `tf.function`.
    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregate metrics."""
      for _ in range(eval_steps):
        test_step(test_iterator)

      with eval_summary_writer.as_default():
        for metric in eval_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          logging.info('Step: [%d] Validation %s = %f', current_training_step,
                       metric.name, metric_value)
          tf.summary.scalar(
              metric.name, metric_value, step=current_training_step)
        eval_summary_writer.flush()

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    # Resume from the latest checkpoint in `model_dir`, if there is one.
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'checkpoint', latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info('Loading from checkpoint file completed')

    # The optimizer's iteration counter is the source of truth for how many
    # steps have already been trained.
    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'

    while current_step < total_training_steps:
      # Training loss/metric are taking average over steps inside micro
      # training loop. We reset the their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      if tf.test.is_built_with_cuda():
        # TODO(zongweiz): merge with train_steps once tf.while_loop
        # GPU performance bugs are fixed.
        for _ in range(steps):
          train_single_step(train_iterator)
      else:
        # Converts steps to a Tensor to avoid tf.function retracing.
        train_steps(train_iterator,
                    tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      current_step += steps
      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})

      # Updates training logging.
      training_status = 'Train Step: %d/%d  / loss = %s' % (
          current_step, total_training_steps, train_loss)

      if train_summary_writer:
        with train_summary_writer.as_default():
          tf.summary.scalar(
              train_loss_metric.name, train_loss, step=current_step)
          for metric in train_metrics + model.metrics:
            metric_value = _float_metric_value(metric)
            training_status += '  %s = %f' % (metric.name, metric_value)
            tf.summary.scalar(metric.name, metric_value, step=current_step)
          train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and run validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(checkpoint, model_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                sub_model_checkpoint, model_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
        if eval_input_fn:
          logging.info('Running evaluation after step: %s.', current_step)
          _run_evaluation(current_step,
                          _get_input_iterator(eval_input_fn, strategy))
          # Re-initialize evaluation metric.
          for metric in eval_metrics + model.metrics:
            metric.reset_states()

    _save_checkpoint(checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(sub_model_checkpoint, model_dir,
                       '%s.ckpt' % sub_model_export_name)

    if eval_input_fn:
      logging.info('Running final evaluation after training is complete.')
      _run_evaluation(current_step,
                      _get_input_iterator(eval_input_fn, strategy))

    training_summary = {
        'total_training_steps': total_training_steps,
        'train_loss': _float_metric_value(train_loss_metric),
    }
    if eval_metrics:
      # TODO(hongkuny): Cleans up summary reporting in text.
      training_summary['last_train_metrics'] = _float_metric_value(
          train_metrics[0])
      training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

    write_txt_summary(training_summary, summary_dir)

    return model