# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import abc
from typing import Dict, Optional, Text
import six
import tensorflow as tf


@six.add_metaclass(abc.ABCMeta)
class AbstractTrainer(tf.Module):
  """An abstract class defining the APIs required for training."""

  @abc.abstractmethod
  def train(self,
            num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """Implements model training with multiple steps.

    In training, it is common to break the total training steps into several
    training loops, so users can do checkpointing, write summaries, and run
    some Python callbacks. This is necessary for getting good performance in
    TPU training, as the overhead of launching a multi-worker `tf.function`
    may be large in eager mode. It is usually encouraged to create a host
    training loop (e.g. using a `tf.range` wrapping `strategy.run` inside a
    `tf.function`) in the TPU case. For cases that don't require a host
    training loop to achieve peak performance, users can implement a simple
    Python loop to drive each step.
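
    Example (a minimal sketch of a host training loop; `self.strategy`,
    `self.train_step`, `self.train_iterator`, and `self.train_loss` are
    hypothetical attributes, not part of this interface):

      def train(self, num_steps):

        @tf.function
        def train_loop(steps):
          # Running the whole loop inside one tf.function amortizes the
          # launch overhead across `steps` iterations.
          for _ in tf.range(steps):
            self.strategy.run(
                self.train_step, args=(next(self.train_iterator),))

        train_loop(num_steps)
        return {"loss": self.train_loss.result()}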

    Args:
      num_steps: A guideline for how many training steps to run. Note that it is
        up to the model what constitutes a "step" (this may involve more than
        one update to model parameters, e.g. if training a GAN).

    Returns:
      The function may return a dictionary of `Tensors` or numpy arrays, which
      will be written to logs and as TensorBoard summaries.
    """
    pass


@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluator(tf.Module):
  """An abstract class defining the APIs required for evaluation."""

  @abc.abstractmethod
  def evaluate(
      self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """Implements model evaluation.

    Args:
      num_steps: A guideline for how many evaluation steps to run. Note that it
        is up to the model what constitutes a "step". Generally, it may be
        desirable to support both running a fixed number of eval steps and
        iterating over the full dataset (however many steps that takes) when
        `num_steps` is `None`.

    Returns:
      The function may return a dictionary of `Tensors` or numpy arrays, which
      will be written to logs and as TensorBoard summaries.
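
    Example (a minimal eager-mode sketch; `self.eval_dataset`,
    `self.eval_step`, and `self.accuracy` are hypothetical attributes, not
    part of this interface):

      def evaluate(self, num_steps):
        for i, batch in enumerate(self.eval_dataset):
          # Stop early if a step budget was given; otherwise consume the
          # full dataset.
          if num_steps is not None and i >= num_steps:
            break
          self.eval_step(batch)
        return {"accuracy": self.accuracy.result()}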
    """
    pass