distributed_executor.py 29.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom training loop for running TensorFlow 2.0 models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os

from absl import flags
from absl import logging
Allen Wang's avatar
Allen Wang committed
26
27

import numpy as np
28
29
30
import tensorflow as tf

# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
Yeqing Li's avatar
Yeqing Li committed
31
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
32
from official.modeling.hyperparams import params_dict
Yeqing Li's avatar
Yeqing Li committed
33
from official.utils import hyperparams_flags
34
from official.common import distribute_utils
Will Cromar's avatar
Will Cromar committed
35
from official.utils.misc import keras_utils
36
37
38

FLAGS = flags.FLAGS

# Module-level aliases of the flag dictionaries defined in `hyperparams_flags`.
strategy_flags_dict, hparam_flags_dict = (
    hyperparams_flags.strategy_flags_dict,
    hyperparams_flags.hparam_flags_dict,
)
41
42
43
44
45
46
47
48
49
50


def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Writes `checkpoint` into `model_dir` under `checkpoint_prefix`.

  Args:
    checkpoint: object exposing `save(path)` (a `tf.train.Checkpoint` in this
      file); the value it returns is logged as the saved path.
    model_dir: directory the checkpoint is written into.
    checkpoint_prefix: file prefix for the checkpoint, e.g. 'ctl_step_0.ckpt'.
  """
  # `save` runs before the log call, exactly as if stored in a temporary.
  logging.info('Saving model as TF checkpoint: %s',
               checkpoint.save(os.path.join(model_dir, checkpoint_prefix)))


Yeqing Li's avatar
Yeqing Li committed
51
52
53
54
55
56
57
def _steps_to_run(current_step, total_steps, steps_per_loop):
  """Returns how many steps the next on-device loop should execute.

  Args:
    current_step: steps already completed.
    total_steps: overall step budget.
    steps_per_loop: maximum number of steps per loop; must be positive.

  Returns:
    `steps_per_loop`, capped by the number of steps still remaining.

  Raises:
    ValueError: if `steps_per_loop` is not positive.
  """
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be positive integer.')
  remaining = total_steps - current_step
  return steps_per_loop if steps_per_loop < remaining else remaining


58
59
60
61
def _no_metric():
  """Default metric factory used when no metric_fn is supplied; yields None."""
  return


Yeqing Li's avatar
Yeqing Li committed
62
63
64
65
def metrics_as_dict(metric):
  """Normalizes input metric(s) into a name -> metric dictionary.

  Args:
    metric: metric(s) to normalize. May be a single
      `tf.keras.metrics.Metric`, a list of such metrics, a dict mapping names
      to metrics, a falsy value (e.g. None), or any other metric-like object.
      Callers in this module use `update_state`, `result` and `reset_states`
      on the values.

  Returns:
    A dictionary of metrics:
      * a single `Metric` is keyed by its `name`;
      * list entries are keyed by their `name` (an entry whose name repeats an
        earlier one replaces it);
      * a dict is returned as-is (not copied);
      * a falsy input yields `{}`;
      * any other object is stored under the key 'metric'.
  """
  if isinstance(metric, tf.keras.metrics.Metric):
    metrics = {metric.name: metric}
  elif isinstance(metric, list):
    metrics = {m.name: m for m in metric}
  elif isinstance(metric, dict):
    metrics = metric
  elif not metric:
    return {}
  else:
    metrics = {'metric': metric}
  return metrics


def metric_results(metric):
  """Evaluates each metric in `metric` and returns the values as floats.

  Args:
    metric: anything accepted by `metrics_as_dict`.

  Returns:
    A dict mapping each metric name to `result()` converted with
    `.numpy().astype(float)`.
  """
  results = {}
  for metric_name, metric_obj in metrics_as_dict(metric).items():
    results[metric_name] = metric_obj.result().numpy().astype(float)
  return results


def reset_states(metric):
  """Calls `reset_states()` on every metric contained in `metric`.

  Args:
    metric: anything accepted by `metrics_as_dict`.
  """
  for metric_obj in metrics_as_dict(metric).values():
    metric_obj.reset_states()


101
102
103
104
class SummaryWriter(object):
  """Simple SummaryWriter for writing dictionary of metrics.

  Attributes:
    writer: The tf.SummaryWriter.
  """

  def __init__(self, model_dir: Text, name: Text):
    """Inits SummaryWriter with paths.

    Args:
      model_dir: the model folder path.
      name: the summary subfolder name; summaries go to `model_dir/name`.
    """
    self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))

  def __call__(self, metrics: Union[Dict[Text, float], float, None],
               step: int):
    """Writes metrics to summary with the given writer.

    Args:
      metrics: a dictionary of metric values (preferred), or a single scalar,
        which is written under the name 'metric'. None (e.g. the result of an
        evaluation that did not run) is skipped with a warning instead of
        crashing inside `tf.summary.scalar`.
      step: the training step (an int or an integer tensor such as
        `optimizer.iterations`).
    """
    if metrics is None:
      logging.warning('Skip writing summary because metrics is None.')
      return

    if not isinstance(metrics, dict):
      # Support scalar metric without name.
      logging.warning('Warning: summary writer prefer metrics as dictionary.')
      metrics = {'metric': metrics}

    with self.writer.as_default():
      for k, v in metrics.items():
        tf.summary.scalar(k, v, step=step)
      self.writer.flush()
133
134
135


class DistributedExecutor(object):
  """Interface to train and eval models with tf.distribute.Strategy."""

  def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
    """Constructor.

    Args:
      strategy: an instance of tf.distribute.Strategy.
      params: Model configuration needed to run distribution strategy. Its
        `as_dict()` result is what gets passed to `model_fn`.
      model_fn: Keras model function. Signature:
        (params: dict) -> tf.keras.models.Model. For training, the returned
        model must carry an `optimizer` attribute.
      loss_fn: zero-argument factory that returns the loss function; the
        returned loss function has signature
        (y_true: Tensor, y_pred: Tensor) -> Tensor.
      is_multi_host: Set to True when using multi hosts for training, like multi
        worker GPU or TPU pod (slice). Otherwise, False.
    """
    self._params = params
    self._model_fn = model_fn
    self._loss_fn = loss_fn
    self._strategy = strategy
    self._checkpoint_name = 'ctl_step_{step}.ckpt'
    self._is_multi_host = is_multi_host
    # Populated by train() / evaluate_*() so that callers can reach the
    # underlying tf.summary writers and the optimizer's step counter.
    self.train_summary_writer = None
    self.eval_summary_writer = None
    self.global_train_step = None

  @property
  def checkpoint_name(self):
    """Returns the checkpoint name template (formatted with `step=`)."""
    return self._checkpoint_name

  @checkpoint_name.setter
  def checkpoint_name(self, name):
    """Sets the checkpoint name template; `{step}` is filled in on save."""
    self._checkpoint_name = name

  def loss_fn(self):
    """Calls the user-supplied loss factory and returns the loss function."""
    return self._loss_fn()

  def model_fn(self, params):
    """Builds the model by calling the user-supplied `model_fn`."""
    return self._model_fn(params)

  def _save_config(self, model_dir):
    """Saves parameters to `<model_dir>/params.yaml` if model_dir is defined.

    `self._params.lock()` is called before the params are written.
    """

    logging.info('Save config to model_dir %s.', model_dir)
    if model_dir:
      if not tf.io.gfile.exists(model_dir):
        tf.io.gfile.makedirs(model_dir)
      self._params.lock()
      params_dict.save_params_dict_to_yaml(
          self._params, os.path.join(model_dir, 'params.yaml'))
    else:
      logging.warning('model_dir is empty, so skip the save config.')

  def _get_input_iterator(
      self, input_fn: Callable[..., tf.data.Dataset],
      strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
    """Returns distributed dataset iterator.

    Args:
      input_fn: callable producing a tf.data.Dataset. In multi-host mode it is
        handed to `strategy.experimental_distribute_datasets_from_function`,
        which decides how to invoke it; otherwise it is called with no
        arguments.
      strategy: an instance of tf.distribute.Strategy.

    Returns:
      An iterator that yields input tensors, or None if `input_fn` is None.
    """

    if input_fn is None:
      return None
    # When training with multiple TPU workers, datasets needs to be cloned
    # across workers. Since Dataset instance cannot be cloned in eager mode,
    # we instead pass callable that returns a dataset.
    if self._is_multi_host:
      return iter(
          strategy.experimental_distribute_datasets_from_function(input_fn))
    else:
      input_data = input_fn()
      return iter(strategy.experimental_distribute_dataset(input_data))

  def _create_replicated_step(self,
                              strategy,
                              model,
                              loss_fn,
                              optimizer,
                              metric=None):
    """Creates a single training step.

    Args:
      strategy: an instance of tf.distribute.Strategy.
      model: (Tensor, bool) -> Tensor. model function.
      loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
      optimizer: tf.keras.optimizers.Optimizer.
      metric: tf.keras.metrics.Metric subclass.

    Returns:
      The training step callable.
    """
    metrics = metrics_as_dict(metric)

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs

      with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        prediction_loss = loss_fn(labels, outputs)
        loss = tf.reduce_mean(prediction_loss)
        # Scale by the replica count so the per-replica losses combine into
        # the global-batch mean across replicas.
        loss = loss / strategy.num_replicas_in_sync
        for m in metrics.values():
          m.update_state(labels, outputs)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss

    return _replicated_step

  def _create_train_step(self,
                         strategy,
                         model,
                         loss_fn,
                         optimizer,
                         metric=None):
    """Creates a distributed training step.

    Args:
      strategy: an instance of tf.distribute.Strategy.
      model: (Tensor, bool) -> Tensor. model function.
      loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
      optimizer: tf.keras.optimizers.Optimizer.
      metric: tf.keras.metrics.Metric subclass.

    Returns:
      The training step callable.
    """
    replicated_step = self._create_replicated_step(strategy, model, loss_fn,
                                                   optimizer, metric)

    @tf.function
    def train_step(iterator, num_steps):
      """Performs a distributed training step.

      Args:
        iterator: an iterator that yields input tensors.
        num_steps: the number of steps in the loop.

      Returns:
        The loss tensor.
      """
      if not isinstance(num_steps, tf.Tensor):
        raise ValueError('steps should be an Tensor. Python object may cause '
                         'retracing.')

      per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
      for _ in tf.range(num_steps - 1):
        per_replica_losses = strategy.run(
            replicated_step, args=(next(iterator),))

      # For reporting, we returns the mean of losses.
      losses = tf.nest.map_structure(
          lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
          per_replica_losses)
      return losses

    return train_step

  def _create_test_step(self, strategy, model, metric):
    """Creates a distributed test step."""
    metrics = metrics_as_dict(metric)

    @tf.function
    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""
      if not metric:
        logging.info('Skip test_step because metric is None (%s)', metric)
        return None, None

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""
        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
        for m in metrics.values():
          m.update_state(labels, model_outputs)
        return labels, model_outputs

      return strategy.run(_test_step_fn, args=(next(iterator),))

    return test_step

  def train(self,
            train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
            eval_input_fn: Callable[[params_dict.ParamsDict],
                                    tf.data.Dataset] = None,
            model_dir: Text = None,
            total_steps: int = 1,
            iterations_per_loop: int = 1,
            train_metric_fn: Callable[[], Any] = None,
            eval_metric_fn: Callable[[], Any] = None,
            summary_writer_fn: Callable[[Text, Text],
                                        SummaryWriter] = SummaryWriter,
            init_checkpoint: Callable[[tf.keras.Model], Any] = None,
            custom_callbacks: List[tf.keras.callbacks.Callback] = None,
            continuous_eval: bool = False,
            save_config: bool = True):
    """Runs distributed training.

    Checkpoints are saved every `FLAGS.save_checkpoint_freq` steps when that
    flag is set, otherwise every `iterations_per_loop` steps, plus once at the
    end of training.

    Args:
      train_input_fn: training data input function returning a
        tf.data.Dataset; see `_get_input_iterator` for how it is invoked.
      eval_input_fn: (Optional) same type as train_input_fn. If not None, will
        trigger evaluating metric on eval data. If None, will not run the eval
        step.
      model_dir: the folder path for model checkpoints.
      total_steps: total training steps.
      iterations_per_loop: train steps per loop. After each loop, this job will
        update metrics like loss and save checkpoint.
      train_metric_fn: metric_fn for evaluation in train_step.
      eval_metric_fn: metric_fn for evaluation in test_step.
      summary_writer_fn: function to create summary writer.
      init_checkpoint: function to load checkpoint.
      custom_callbacks: A list of Keras Callbacks objects to run during
        training. More specifically, `on_batch_begin()`, `on_batch_end()`,
        methods are invoked during training.
      continuous_eval: If `True`, also runs evaluation after every training
        loop (every `iterations_per_loop` steps). When evaluation is configured
        it always runs before training starts and once after the final step.
      save_config: bool. Whether to save params to model_dir.

    Returns:
      A tuple (train metrics dict, eval metrics dict or None).
    """
    assert train_input_fn is not None
    if train_metric_fn and not callable(train_metric_fn):
      raise ValueError('if `train_metric_fn` is specified, '
                       'train_metric_fn must be a callable.')
    if eval_metric_fn and not callable(eval_metric_fn):
      raise ValueError('if `eval_metric_fn` is specified, '
                       'eval_metric_fn must be a callable.')
    train_metric_fn = train_metric_fn or _no_metric
    eval_metric_fn = eval_metric_fn or _no_metric

    if custom_callbacks and iterations_per_loop != 1:
      logging.warning(
          'It is sematically wrong to run callbacks when '
          'iterations_per_loop is not one (%s)', iterations_per_loop)

    custom_callbacks = custom_callbacks or []

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        if callback:
          callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch):
      """Runs custom callbacks at the end of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        if callback:
          callback.on_batch_end(batch)

    if save_config:
      self._save_config(model_dir)

    if FLAGS.save_checkpoint_freq:
      save_freq = FLAGS.save_checkpoint_freq
    else:
      save_freq = iterations_per_loop

    params = self._params
    strategy = self._strategy
    # To reduce unnecessary send/receive input pipeline operation, we place
    # input pipeline ops in worker task.
    train_iterator = self._get_input_iterator(train_input_fn, strategy)
    train_loss = None
    train_metric_result = None
    eval_metric_result = None
    tf.keras.backend.set_learning_phase(1)
    with strategy.scope():
      # To correctly place the model weights on accelerators,
      # model and optimizer should be created in scope.
      model = self.model_fn(params.as_dict())
      if not hasattr(model, 'optimizer'):
        raise ValueError('User should set optimizer attribute to model '
                         'inside `model_fn`.')
      optimizer = model.optimizer

      # Training loop starts here.
      checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
      latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
      initial_step = 0
      if latest_checkpoint_file:
        logging.info(
            'Checkpoint file %s found and restoring from '
            'checkpoint', latest_checkpoint_file)
        checkpoint.restore(latest_checkpoint_file)
        initial_step = optimizer.iterations.numpy()
        logging.info('Loading from checkpoint file completed. Init step %d',
                     initial_step)
      elif init_checkpoint:
        logging.info('Restoring from init checkpoint function')
        init_checkpoint(model)
        logging.info('Loading from init checkpoint file completed')

      current_step = optimizer.iterations.numpy()
      checkpoint_name = self.checkpoint_name

      eval_metric = eval_metric_fn()
      train_metric = train_metric_fn()
      train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
      self.train_summary_writer = train_summary_writer.writer

      test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
      self.eval_summary_writer = test_summary_writer.writer

    # Use training summary writer in TimeHistory if it's in use
    for cb in custom_callbacks:
      if isinstance(cb, keras_utils.TimeHistory):
        cb.summary_writer = self.train_summary_writer

    # Continue training loop.
    train_step = self._create_train_step(
        strategy=strategy,
        model=model,
        loss_fn=self.loss_fn(),
        optimizer=optimizer,
        metric=train_metric)
    test_step = None
    if eval_input_fn and eval_metric:
      self.global_train_step = model.optimizer.iterations
      test_step = self._create_test_step(strategy, model, metric=eval_metric)

    # Step-0 operations
    if current_step == 0 and not latest_checkpoint_file:
      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))
    if test_step:
      eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
      eval_metric_result = self._run_evaluation(test_step, current_step,
                                                eval_metric, eval_iterator)
      logging.info('Step: %s evalation metric = %s.', current_step,
                   eval_metric_result)
      test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
      reset_states(eval_metric)

    logging.info('Training started')
    last_save_checkpoint_step = current_step
    while current_step < total_steps:

      num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
      _run_callbacks_on_batch_begin(current_step)
      train_loss = train_step(train_iterator,
                              tf.convert_to_tensor(num_steps, dtype=tf.int32))
      current_step += num_steps

      train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                         train_loss)

      _run_callbacks_on_batch_end(current_step - 1)
      if not isinstance(train_loss, dict):
        train_loss = {'total_loss': train_loss}
      if np.isnan(train_loss['total_loss']):
        raise ValueError('total loss is NaN.')

      if train_metric:
        train_metric_result = metric_results(train_metric)
        train_metric_result.update(train_loss)
      else:
        # Copy so that adding 'learning_rate' below does not also mutate
        # `train_loss`, which is logged separately.
        train_metric_result = dict(train_loss)
      if callable(optimizer.lr):
        train_metric_result.update(
            {'learning_rate': optimizer.lr(current_step).numpy()})
      else:
        train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
      logging.info('Train Step: %d/%d  / loss = %s / training metric = %s',
                   current_step, total_steps, train_loss, train_metric_result)

      train_summary_writer(
          metrics=train_metric_result, step=optimizer.iterations)

      # Saves model checkpoints and run validation steps at every
      # iterations_per_loop steps.
      # To avoid repeated model saving, we do not save after the last
      # step of training.
      if save_freq > 0 and current_step < total_steps and (
          current_step - last_save_checkpoint_step) >= save_freq:
        _save_checkpoint(checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))
        last_save_checkpoint_step = current_step

      if continuous_eval and current_step < total_steps and test_step:
        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
        eval_metric_result = self._run_evaluation(test_step, current_step,
                                                  eval_metric, eval_iterator)
        logging.info('Step: %s evalation metric = %s.', current_step,
                     eval_metric_result)
        test_summary_writer(
            metrics=eval_metric_result, step=optimizer.iterations)

      # Re-initialize evaluation metric, except the last step.
      if eval_metric and current_step < total_steps:
        reset_states(eval_metric)
      if train_metric and current_step < total_steps:
        reset_states(train_metric)

    # Reaches the end of training and saves the last checkpoint.
    if last_save_checkpoint_step < total_steps:
      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))

    if test_step:
      logging.info('Running final evaluation after training is complete.')
      eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
      eval_metric_result = self._run_evaluation(test_step, current_step,
                                                eval_metric, eval_iterator)
      logging.info('Final evaluation metric = %s.', eval_metric_result)
      test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)

    self.train_summary_writer.close()
    self.eval_summary_writer.close()

    return train_metric_result, eval_metric_result

  def _run_evaluation(self, test_step, current_training_step, metric,
                      test_iterator):
    """Runs validation steps and aggregate metrics.

    Returns:
      The metric results dict, or None if `test_iterator` or `metric` is
      missing.
    """
    if not test_iterator or not metric:
      logging.warning(
          'Both test_iterator (%s) and metrics (%s) must not be None.',
          test_iterator, metric)
      return None
    logging.info('Running evaluation after step: %s.', current_training_step)
    eval_step = 0
    # Consume the eval iterator until it is exhausted.
    while True:
      try:
        with tf.experimental.async_scope():
          test_step(test_iterator)
          eval_step += 1
      except (StopIteration, tf.errors.OutOfRangeError):
        tf.experimental.async_clear_error()
        break

    metric_result = metric_results(metric)
    logging.info('Total eval steps: [%d]', eval_step)
    logging.info('At training step: [%r] Validation metric = %r',
                 current_training_step, metric_result)
    return metric_result

  def evaluate_from_model_dir(
      self,
      model_dir: Text,
      eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
      eval_metric_fn: Callable[[], Any],
      total_steps: int = -1,
      eval_timeout: int = None,
      min_eval_interval: int = 180,
      summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
    """Runs distributed evaluation on model folder.

    Args:
      model_dir: the folder for storing model checkpoints.
      eval_input_fn: (Optional) same type as train_input_fn. If not None, will
        trigger evaluating metric on eval data. If None, will not run eval step.
      eval_metric_fn: metric_fn for evaluation in test_step.
      total_steps: total training steps. If the current step reaches the
        total_steps, the evaluation loop will stop.
      eval_timeout: The maximum number of seconds to wait between checkpoints.
        If left as None, then the process will wait indefinitely. Used by
        tf.train.checkpoints_iterator.
      min_eval_interval: The minimum number of seconds between yielding
        checkpoints. Used by tf.train.checkpoints_iterator.
      summary_writer_fn: function to create summary writer.

    Returns:
      Eval metrics dictionary of the last evaluated checkpoint, or None if no
      checkpoint was evaluated before the timeout.

    Raises:
      ValueError: if `model_dir` is empty.
    """

    if not model_dir:
      raise ValueError('model_dir must be set.')

    def terminate_eval():
      # `tf.logging` does not exist in TF2; use the absl logger like the rest
      # of this module.
      logging.info('Terminating eval after %d seconds of no checkpoints',
                   eval_timeout)
      return True

    summary_writer = summary_writer_fn(model_dir, 'eval')
    self.eval_summary_writer = summary_writer.writer

    # Stays None if no checkpoint appears before `eval_timeout` expires.
    eval_metric_result = None
    # Read checkpoints from the given model directory
    # until `eval_timeout` seconds elapses.
    for checkpoint_path in tf.train.checkpoints_iterator(
        model_dir,
        min_interval_secs=min_eval_interval,
        timeout=eval_timeout,
        timeout_fn=terminate_eval):
      eval_metric_result, current_step = self.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=eval_metric_fn,
          summary_writer=summary_writer)
      if total_steps > 0 and current_step >= total_steps:
        logging.info('Evaluation finished after training step %d', current_step)
        break
    return eval_metric_result

  def evaluate_checkpoint(self,
                          checkpoint_path: Text,
                          eval_input_fn: Callable[[params_dict.ParamsDict],
                                                  tf.data.Dataset],
                          eval_metric_fn: Callable[[], Any],
                          summary_writer: SummaryWriter = None):
    """Runs distributed evaluation on the one checkpoint.

    Args:
      checkpoint_path: the checkpoint to evaluate.
      eval_input_fn: (Optional) same type as train_input_fn. If not None, will
        trigger evaluating metric on eval data. If None, will not run eval step.
      eval_metric_fn: metric_fn for evaluation in test_step.
      summary_writer: optional `SummaryWriter` used to record the results;
        nothing is written when it is None.

    Returns:
      A tuple (eval metrics dictionary, training step read from the
      checkpoint's optimizer iteration counter).

    Raises:
      ValueError: if `eval_metric_fn` is not callable or `checkpoint_path` is
        empty.
    """
    if not callable(eval_metric_fn):
      raise ValueError('if `eval_metric_fn` is specified, '
                       'eval_metric_fn must be a callable.')

    old_phase = tf.keras.backend.learning_phase()
    tf.keras.backend.set_learning_phase(0)
    params = self._params
    strategy = self._strategy
    try:
      # To reduce unnecessary send/receive input pipeline operation, we place
      # input pipeline ops in worker task.
      with strategy.scope():

        # To correctly place the model weights on accelerators,
        # model and optimizer should be created in scope.
        model = self.model_fn(params.as_dict())
        checkpoint = tf.train.Checkpoint(model=model)

        eval_metric = eval_metric_fn()
        assert eval_metric, 'eval_metric does not exist'
        test_step = self._create_test_step(strategy, model, metric=eval_metric)

        logging.info('Starting to evaluate.')
        if not checkpoint_path:
          raise ValueError('checkpoint path is empty')
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
        current_step = reader.get_tensor(
            'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
        logging.info('Checkpoint file %s found and restoring from '
                     'checkpoint', checkpoint_path)
        status = checkpoint.restore(checkpoint_path)
        status.expect_partial().assert_existing_objects_matched()

        self.global_train_step = model.optimizer.iterations
        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
        eval_metric_result = self._run_evaluation(test_step, current_step,
                                                  eval_metric, eval_iterator)
        logging.info('Step: %s evalation metric = %s.', current_step,
                     eval_metric_result)
        if summary_writer is not None:
          summary_writer(metrics=eval_metric_result, step=current_step)
        reset_states(eval_metric)
    finally:
      # Restore the caller's learning phase even if evaluation fails.
      tf.keras.backend.set_learning_phase(old_phase)

    return eval_metric_result, current_step

  def predict(self):
    """Prediction is not supported by this executor.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError('Unimplmented function.')


class ExecutorBuilder(object):
  """Builder of DistributedExecutor.

  `strategy_config` is read through attribute access (`worker_hosts`,
  `task_index`, `num_gpus`, `all_reduce_alg`, `num_packs`, `tpu`), so it must
  be an attribute-style config object (e.g. a ParamsDict) rather than a plain
  dict.

  Example 1: Builds an executor with supported Strategy.
    builder = ExecutorBuilder(
        strategy_type='tpu',
        strategy_config=strategy_config)
    dist_executor = builder.build_executor(
        params=params,
        model_fn=my_model_fn,
        loss_fn=my_loss_fn,
        metric_fn=my_metric_fn)

  Example 2: Builds an executor with customized Strategy.
    builder = ExecutorBuilder()
    builder.strategy = <some customized Strategy>
    dist_executor = builder.build_executor(
        params=params,
        model_fn=my_model_fn,
        loss_fn=my_loss_fn,
        metric_fn=my_metric_fn)

  Example 3: Builds a customized executor with customized Strategy.
    class MyDistributedExecutor(DistributedExecutor):
      # implementation ...

    builder = ExecutorBuilder()
    builder.strategy = <some customized Strategy>
    dist_executor = builder.build_executor(
        class_ctor=MyDistributedExecutor,
        params=params,
        model_fn=my_model_fn,
        loss_fn=my_loss_fn,
        metric_fn=my_metric_fn)
  """

  def __init__(self, strategy_type=None, strategy_config=None):
    """Constructor.

    Args:
      strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
        If None, no strategy is created and the user is responsible for
        setting the `strategy` property before calling build_executor(...).
      strategy_config: necessary config for constructing the proper Strategy,
        accessed by attribute (see the class docstring). Check
        strategy_flags_dict() for examples of the structure. May be None only
        when `strategy_type` is None.

    Raises:
      ValueError: if `strategy_type` is given but `strategy_config` is None.
    """
    if strategy_config is not None:
      # Configure the cluster from the worker hosts / task index before any
      # strategy is constructed.
      distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)

    if strategy_type is None:
      # Strategy is supplied later via the `strategy` setter; build_executor()
      # raises a ValueError until it is set.
      self._strategy = None
      return

    if strategy_config is None:
      raise ValueError('`strategy_config` is required when `strategy_type` '
                       'is specified.')

    self._strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  @property
  def strategy(self):
    """Returns the distribution strategy used to build executors (or None)."""
    return self._strategy

  @strategy.setter
  def strategy(self, new_strategy):
    """Sets the distribution strategy used by build_executor()."""
    self._strategy = new_strategy

  def build_executor(self,
                     class_ctor=DistributedExecutor,
                     params=None,
                     model_fn=None,
                     loss_fn=None,
                     **kwargs):
    """Creates an executor according to strategy type.

    See the docstring of DistributedExecutor.__init__ for more information
    about the input arguments.

    Args:
      class_ctor: A constructor of executor (default: DistributedExecutor).
      params: ParamsDict, all the model parameters and runtime parameters.
      model_fn: Keras model function.
      loss_fn: loss function.
      **kwargs: other arguments to the executor constructor.

    Returns:
      An instance of DistributedExecutor or its subclass.

    Raises:
      ValueError: if no strategy has been configured on this builder.
    """
    if self._strategy is None:
      raise ValueError('`strategy` should not be None. You need to specify '
                       '`strategy_type` in the builder contructor or directly '
                       'set the `strategy` property of the builder.')
    return class_ctor(
        strategy=self._strategy,
        params=params,
        model_fn=model_fn,
        loss_fn=loss_fn,
        **kwargs)