# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

from absl import flags
import tensorflow as tf

from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order


def define_flags():
  """Register the flag groups exercised by the tests in this module.

  Defines the base flags (called with num_gpu=False) plus the performance,
  image and benchmark flag groups from flags_core.

  NOTE(review): absl normally rejects defining the same flag name twice in
  one process, so this is presumably meant to be called only once (it is
  invoked from BaseTester.setUpClass) — confirm before reusing elsewhere.
  """
  flags_core.define_base(num_gpu=False)
  flags_core.define_performance()
  flags_core.define_image()
  flags_core.define_benchmark()


class BaseTester(unittest.TestCase):
  """Tests for defining, defaulting and parsing flags via flags_core."""

  @classmethod
  def setUpClass(cls):
    super(BaseTester, cls).setUpClass()
    # Flags live in the process-global absl FLAGS registry, so they are
    # defined once for the whole class rather than once per test.
    define_flags()

  def _assert_flag_values(self, expected):
    """Assert that every flag named in `expected` holds the given value.

    Args:
      expected: dict mapping a flag name to the value it should have after
        the most recent parse.
    """
    for key, value in expected.items():
      actual = flags.FLAGS.get_flag_value(name=key, default=None)
      self.assertEqual(
          actual, value,
          "flag {!r}: expected {!r}, got {!r}".format(key, value, actual))

  def test_default_setting(self):
    """Test to ensure fields exist and defaults can be set."""
    defaults = dict(
        data_dir="dfgasf",
        model_dir="dfsdkjgbs",
        train_epochs=534,
        epochs_between_evals=15,
        batch_size=256,
        hooks=["LoggingTensorHook"],
        num_parallel_calls=18,
        inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
        data_format="channels_first"
    )

    flags_core.set_defaults(**defaults)
    # Pass an explicit argv (program name only) so that parsing does not
    # pick up the test runner's own command-line arguments, which the
    # no-argument form would presumably read from sys.argv.
    flags_core.parse_flags([__file__])

    self._assert_flag_values(defaults)

  def test_benchmark_setting(self):
    """Test that benchmark-related flag defaults can be set."""
    defaults = dict(
        hooks=["LoggingMetricHook"],
        benchmark_log_dir="/tmp/12345",
        gcp_project="project_abc",
    )

    flags_core.set_defaults(**defaults)
    # Explicit argv for the same reason as in test_default_setting.
    flags_core.parse_flags([__file__])

    self._assert_flag_values(defaults)

  def test_booleans(self):
    """Test to ensure boolean flags trigger as expected."""
    flags_core.parse_flags([__file__, "--use_synthetic_data"])

    self.assertTrue(flags.FLAGS.use_synthetic_data)

  def test_parse_dtype_info(self):
    """Test dtype parsing, per-dtype loss scales and the loss scale override."""
    cases = (
        ("fp16", tf.float16, 128),
        ("fp32", tf.float32, 1),
    )
    for dtype_str, tf_dtype, loss_scale in cases:
      flags_core.parse_flags([__file__, "--dtype", dtype_str])

      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype,
                       "dtype=" + dtype_str)
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale,
                       "dtype=" + dtype_str)

      # An explicit --loss_scale takes precedence over the dtype's default.
      flags_core.parse_flags(
          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])

      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5,
                       "dtype=" + dtype_str)

    # An unsupported dtype is rejected at parse time, which exits.
    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])

if __name__ == "__main__":
  # Run the tests in this module when it is executed directly as a script;
  # module="__main__" spells out unittest.main's default target.
  unittest.main(module="__main__")