# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.standard_runner."""

from orbit import standard_runner
from orbit import utils

import tensorflow as tf


def dataset_fn(input_context=None):
  """Returns an infinite dataset of constant zero tensors for tests.

  Args:
    input_context: Ignored. Accepted only so this function has the
      `dataset_fn(input_context)` shape that the distribution strategy's
      dataset-from-function API calls.

  Returns:
    A `tf.data.Dataset` that repeats forever. Every element is a float32
    tensor of zeros with shape (1, 1).
  """
  del input_context  # Unused.

  def _zeros(unused_index):
    # The element coming from `range` is discarded; only the shape matters.
    return tf.zeros((1, 1), dtype=tf.float32)

  return tf.data.Dataset.range(1).repeat().map(
      _zeros, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class TestTrainer(standard_runner.StandardTrainer):
  """A StandardTrainer subclass for tests.

  `train_loop_begin` resets `global_step` to zero. Each `train_step` runs a
  replica function that increments it by one, and `train_loop_end` returns
  its value. Under the default (single-replica) strategy, the returned value
  is therefore the number of steps executed in the loop.
  """

  def __init__(self, options=None):
    """Creates the trainer on the current distribution strategy.

    Args:
      options: Optional `standard_runner.StandardTrainerOptions`, forwarded
        unchanged to `StandardTrainer.__init__`.
    """
    self.strategy = tf.distribute.get_strategy()
    self.global_step = utils.create_global_step()
    distribute = self.strategy.experimental_distribute_datasets_from_function
    dataset = distribute(dataset_fn)
    super().__init__(train_dataset=dataset, options=options)

  def train_loop_begin(self):
    # Start each loop from zero so the value reported by `train_loop_end`
    # reflects only this loop's steps.
    self.global_step.assign(0)

  def train_step(self, iterator):

    def replica_step(_):
      # The input batch is ignored; the step only counts invocations.
      self.global_step.assign_add(1)

    self.strategy.run(replica_step, args=(next(iterator),))

  def train_loop_end(self):
    return self.global_step.numpy()


class TestEvaluator(standard_runner.StandardEvaluator):
  """A StandardEvaluator subclass for tests.

  `eval_begin` resets `global_step` to zero. Each `eval_step` runs a replica
  function that increments it by one, and `eval_end` returns its value.
  Under the default (single-replica) strategy, the returned value is
  therefore the number of evaluation steps executed.
  """

  def __init__(self, options=None):
    """Creates the evaluator on the current distribution strategy.

    Args:
      options: Optional evaluator options, forwarded unchanged to
        `StandardEvaluator.__init__`.
    """
    self.strategy = tf.distribute.get_strategy()
    self.global_step = utils.create_global_step()
    distribute = self.strategy.experimental_distribute_datasets_from_function
    dataset = distribute(dataset_fn)
    super().__init__(eval_dataset=dataset, options=options)

  def eval_begin(self):
    # Start each evaluation from zero so `eval_end` reports only its steps.
    self.global_step.assign(0)

  def eval_step(self, iterator):

    def replica_step(_):
      # The input batch is ignored; the step only counts invocations.
      self.global_step.assign_add(1)

    self.strategy.run(replica_step, args=(next(iterator),))

  def eval_end(self):
    return self.global_step.numpy()


class StandardRunnerTest(tf.test.TestCase):
  """Checks that the test runners report the number of steps they ran."""

  def test_default_trainer(self):
    """Training for 10 steps with default options reports 10."""
    trainer = TestTrainer()
    self.assertEqual(trainer.train(tf.constant(10)), 10)

  def test_trainer_with_tpu_summary_optimization(self):
    """The step count is unchanged when TPU summary optimization is on."""
    options = standard_runner.StandardTrainerOptions(
        use_tpu_summary_optimization=True)
    trainer = TestTrainer(options)
    self.assertEqual(trainer.train(tf.constant(10)), 10)

  def test_default_evaluator(self):
    """Evaluating for 10 steps with default options reports 10."""
    evaluator = TestEvaluator()
    self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)


if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()