# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to train BERT models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

from absl import logging
import tensorflow as tf

SUMMARY_TXT = 'training_summary.txt'


def get_primary_cpu_task(use_remote_tpu=False):
  """Returns primary CPU task to which input pipeline Ops are put."""

  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
  return '/job:worker' if use_remote_tpu else ''


def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves model to with provided checkpoint prefix."""

  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
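  # Note (illustrative): `tf.train.Checkpoint.save` appends a save counter to
  # the prefix, so the files on disk look like '<model_dir>/<prefix>-1.index'
  # and '<model_dir>/<prefix>-1.data-00000-of-00001'.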
  saved_path = checkpoint.save(checkpoint_path)
  logging.info('Saving model as TF checkpoint: %s', saved_path)
  return


def _get_input_iterator(input_fn, strategy):
  """Returns distributed dataset iterator."""

  # When training with TPU pods, datasets need to be cloned across workers.
  # Since a Dataset instance cannot be cloned in eager mode, we instead pass a
  # callable that returns a dataset.
  input_data = input_fn()
  if callable(input_data):
    iterator = iter(
        strategy.experimental_distribute_datasets_from_function(input_data))
  else:
    iterator = iter(strategy.experimental_distribute_dataset(input_data))
  return iterator
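

# A minimal sketch (not used by this module) of the callable form of
# `input_fn` described above for TPU pods: instead of a dataset, `input_fn`
# returns a function that receives a tf.distribute.InputContext.
# `file_pattern` and `global_batch_size` are assumed example names.
def _example_dataset_fn_factory(file_pattern, global_batch_size):
  """Illustrative: builds a per-replica dataset function for TPU pods."""

  def dataset_fn(ctx):
    # Shard records across input pipelines and batch per replica.
    per_replica_batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    return dataset.batch(per_replica_batch_size, drop_remainder=True)

  return dataset_fn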


def _float_metric_value(metric):
  """Gets the value of a float-value keras metric."""
  return metric.result().numpy().astype(float)


def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Calculates steps to run on device."""
  if steps_per_loop <= 1:
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  else:
    return steps_per_loop


def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    metric_fn=None,
    init_checkpoint=None,
    use_remote_tpu=False,
    custom_callbacks=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model). Caller of this
        function should add optimizer to the `model` via calling
        `model.compile()` API or manually setting `model.optimizer` attribute.
        Second element of the returned tuple(sub_model) is an optional sub model
        to be used for initial checkpoint -- if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, a model checkpoint will be saved and evaluation will be
        conducted if an evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. To reduce
        communication in eager context, training logs are printed every
        `steps_per_loop` steps.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns the evaluation dataset. If None,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if
        `eval_input_fn` is not None.
      metric_fn: A metrics function that returns a Keras Metric object, used
        to record evaluation results on the evaluation dataset, or on the
        training dataset, after every epoch.
      init_checkpoint: Optional checkpoint to load into the `sub_model`
        returned by `model_fn`.
      use_remote_tpu: If True, input pipeline ops are placed on the TPU worker
        host as an optimization.
      custom_callbacks: A list of Keras Callback objects to run during
        training. More specifically, the `on_batch_begin()` and
        `on_batch_end()` methods are invoked during training.

  Returns:
      Trained model.

  Raises:
      ValueError: (1) When the model returned by `model_fn` does not have an
        optimizer attribute or when required parameters are set to None; (2)
        when eval args are not specified correctly; (3) when `metric_fn` is
        specified but is not a callable.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')
  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        'steps_per_epoch: %d, we will use steps_per_epoch as '
        'steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if eval_input_fn and (eval_steps is None or metric_fn is None):
    raise ValueError(
        '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
        'is not None.')
  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'If `metric_fn` is specified, `metric_fn` must be a callable.')

  # To reduce unnecessary send/receive input pipeline operations, we place
  # input pipeline ops on the worker task.
  with tf.device(get_primary_cpu_task(use_remote_tpu)):
    train_iterator = _get_input_iterator(train_input_fn, strategy)

    with strategy.scope():
      total_training_steps = steps_per_epoch * epochs

      # To correctly place the model weights on accelerators, the model and
      # optimizer should be created under the strategy scope.
      model, sub_model = model_fn()
      if not hasattr(model, 'optimizer'):
        raise ValueError('User should set optimizer attribute to model '
                         'inside `model_fn`.')
      optimizer = model.optimizer

      if init_checkpoint:
        logging.info(
            'Checkpoint file %s found and restoring from '
            'initial checkpoint for core model.', init_checkpoint)
        checkpoint = tf.train.Checkpoint(model=sub_model)
        checkpoint.restore(init_checkpoint).assert_consumed()
        logging.info('Loading from checkpoint file completed')

      train_loss_metric = tf.keras.metrics.Mean(
          'training_loss', dtype=tf.float32)
      eval_metric = metric_fn() if metric_fn else None
      # If evaluation is required, make a copy of the metric, as it will be
      # used by both training and evaluation.
      train_metric = (
          eval_metric.__class__.from_config(eval_metric.get_config())
          if eval_metric else None)

      @tf.function
      def train_step(iterator):
        """Performs a distributed training step."""

        def _replicated_step(inputs):
          """Replicated training step."""

          inputs, labels = inputs
          with tf.GradientTape() as tape:
            model_outputs = model(inputs)
            loss = loss_fn(labels, model_outputs)

          tvars = model.trainable_variables
          grads = tape.gradient(loss, tvars)
          optimizer.apply_gradients(zip(grads, tvars))
          # For reporting, the metric takes the mean of losses.
          train_loss_metric.update_state(loss)
          if train_metric:
            train_metric.update_state(labels, model_outputs)

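        # `next(iterator)` yields one distributed batch; `experimental_run_v2`
        # runs `_replicated_step` on every replica with its local shard.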
        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

      @tf.function
      def test_step(iterator):
        """Calculates evaluation metrics on distributed devices."""

        def _test_step_fn(inputs):
          """Replicated accuracy calculation."""

          inputs, labels = inputs
          model_outputs = model(inputs, training=False)
          eval_metric.update_state(labels, model_outputs)

        strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))

      def _run_evaluation(current_training_step, test_iterator):
        """Runs validation steps and aggregate metrics."""
        for _ in range(eval_steps):
          test_step(test_iterator)
        logging.info('Step: [%d] Validation metric = %f', current_training_step,
                     _float_metric_value(eval_metric))

      def _run_callbacks_on_batch_begin(batch):
        """Runs custom callbacks at the start of every step."""
        if not custom_callbacks:
          return
        for callback in custom_callbacks:
          callback.on_batch_begin(batch)

      def _run_callbacks_on_batch_end(batch):
        """Runs custom callbacks at the end of every step."""
        if not custom_callbacks:
          return
        for callback in custom_callbacks:
          callback.on_batch_end(batch)

      # Training loop starts here.
      checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
      latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
      if latest_checkpoint_file:
        logging.info(
            'Checkpoint file %s found and restoring from '
            'checkpoint', latest_checkpoint_file)
        checkpoint.restore(latest_checkpoint_file)
        logging.info('Loading from checkpoint file completed')

      current_step = optimizer.iterations.numpy()
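      # `optimizer.iterations` was restored with the checkpoint above (if one
      # existed), so `current_step` resumes from the last saved global step.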
      checkpoint_name = 'ctl_step_{step}.ckpt'

      while current_step < total_training_steps:
        # Training loss/metric are averaged over the steps inside the micro
        # training loop. We reset their values before each round.
        train_loss_metric.reset_states()
        if train_metric:
          train_metric.reset_states()

        state_step = current_step
        _run_callbacks_on_batch_begin(state_step)
        for _ in range(
            _steps_to_run(state_step, steps_per_epoch, steps_per_loop)):
          current_step += 1
          train_step(train_iterator)
        _run_callbacks_on_batch_end(state_step)

        # Updates training logging.
        training_status = 'Train Step: %d/%d  / loss = %s' % (
            current_step, total_training_steps,
            _float_metric_value(train_loss_metric))
        if train_metric:
          training_status += ' training metric = %s' % _float_metric_value(
              train_metric)
        logging.info(training_status)

        # Save model checkpoints and run validation steps at every epoch end.
        if current_step % steps_per_epoch == 0:
          # To avoid repeated model saving, we do not save after the last
          # step of training.
          if current_step < total_training_steps:
            _save_checkpoint(checkpoint, model_dir,
                             checkpoint_name.format(step=current_step))

          if eval_input_fn:
            logging.info('Running evaluation after step: %s.', current_step)
            _run_evaluation(current_step,
                            _get_input_iterator(eval_input_fn, strategy))
            # Re-initialize evaluation metric.
            eval_metric.reset_states()

      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))

      if eval_input_fn:
        logging.info('Running final evaluation after training is complete.')
        _run_evaluation(current_step,
                        _get_input_iterator(eval_input_fn, strategy))

      training_summary = {
          'total_training_steps': total_training_steps,
          'train_loss': _float_metric_value(train_loss_metric),
      }
      if eval_metric:
        training_summary['last_train_metrics'] = _float_metric_value(
            train_metric)
        training_summary['eval_metrics'] = _float_metric_value(eval_metric)

      summary_path = os.path.join(model_dir, SUMMARY_TXT)
      with tf.io.gfile.GFile(summary_path, 'wb') as f:
        logging.info('Training Summary: \n%s', str(training_summary))
        f.write(json.dumps(training_summary, indent=4))

      return model
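

# The sketch below (not used by this module) illustrates how the training
# loop above is typically driven. The model/input functions and all
# hyperparameter values are assumptions made for the example only.
def _example_training_driver(strategy, model_fn, train_input_fn,
                             eval_input_fn):
  """Illustrative driver for `run_customized_training_loop`."""
  return run_customized_training_loop(
      strategy=strategy,
      model_fn=model_fn,  # must set `model.optimizer` on the returned model
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      model_dir='/tmp/example_model_dir',
      train_input_fn=train_input_fn,
      steps_per_epoch=1000,
      steps_per_loop=200,
      epochs=3,
      eval_input_fn=eval_input_fn,
      eval_steps=100,
      metric_fn=lambda: tf.keras.metrics.SparseCategoricalAccuracy(),
      use_remote_tpu=False)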