runner.py 3.5 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
15
"""Provides AbstractTrainer/Evaluator base classes, defining train/eval APIs."""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17

import abc
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
18
19
20
21

from typing import Dict, Optional, Union

import numpy as np
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
22
23
24
import tensorflow as tf


Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
25
26
27
# Return structure shared by `AbstractTrainer.train` and
# `AbstractEvaluator.evaluate`: a dict keyed by string name whose values are
# Tensors, Python floats, NumPy scalars (`np.number`), NumPy arrays, or another
# `Output` dict. The quoted 'Output' forward reference makes the alias
# recursive, which is what allows arbitrarily nested dictionaries.
# NOTE(review): the pytype directive presumably suppresses pytype's complaint
# about this recursive alias; it must stay on the same line as the alias.
Output = Dict[str, Union[tf.Tensor, float, np.number, np.ndarray, 'Output']]  # pytype: disable=not-supported-yet


Hongkun Yu's avatar
Hongkun Yu committed
28
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
  """Base class declaring the training API that a `Controller` drives.

  Concrete subclasses must override `train`. Because `train` is abstract,
  this class itself cannot be instantiated.
  """

  @abc.abstractmethod
  def train(self, num_steps: tf.Tensor) -> Optional[Output]:
    """Runs `num_steps` steps of training.

    The `Controller` invokes this method as the "inner loop" of training.
    Grouping many steps into a single call spreads the fixed cost of
    bookkeeping (checkpointing, evaluation, and summary writing) across those
    steps. The loop may be written with TensorFlow looping constructs, for
    example a `for` loop over a `tf.range` inside a `tf.function`. Doing so
    can be necessary for peak performance on TPU. When peak performance is
    not a concern, a plain Python loop is a simpler alternative.

    Args:
      num_steps: How many training steps to run. What counts as one "step" is
        left to the model; a single step may perform several updates to the
        model parameters (as when training a GAN).

    Returns:
      `None`, or a dictionary mapping names to `Tensor`s, Python floats, or
      NumPy values. A returned dictionary is written to the logs and as
      TensorBoard summaries. Nested dictionaries produce a matching hierarchy
      of summary directories.
    """


Hongkun Yu's avatar
Hongkun Yu committed
58
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining the API required for evaluation.

  Concrete subclasses must override `evaluate`. Because `evaluate` is
  abstract, this class itself cannot be instantiated.
  """

  @abc.abstractmethod
  def evaluate(self, num_steps: tf.Tensor) -> Optional[Output]:
    """Implements `num_steps` steps of evaluation.

    This method will be called by the `Controller` to perform an evaluation.
    The `num_steps` parameter specifies the number of steps of evaluation to
    run, which is specified by the user when calling one of the
    `Controller`'s evaluation methods. A special sentinel value of `-1` is
    reserved to indicate evaluation should run until the underlying data
    source is exhausted.

    Args:
      num_steps: The number of evaluation steps to run. Note that it is up to
        the model what constitutes a "step". Evaluations may also want to
        support "complete" evaluations when `num_steps == -1`, running until a
        given data source is exhausted.

    Returns:
      Either `None`, or a dictionary mapping names to `Tensor`s, Python
      floats, or NumPy values. If a dictionary is returned, it will be written
      to logs and as TensorBoard summaries. The dictionary may also be nested,
      which will generate a hierarchy of summary directories.
    """