mnist_test.py 2.52 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Test the Keras MNIST model on GPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import integration
from official.vision.image_classification import mnist_main


Will Cromar's avatar
Will Cromar committed
32
33
34
# Runs at import time, so the MNIST flag definitions are executed before any
# test in this module starts.
mnist_main.define_mnist_flags()


35
36
37
38
def eager_strategy_combinations():
  """Builds the `distribution` parameter sets used by the tests in this file.

  Returns:
    The result of `combinations.combine` with `distribution` ranging over the
    default strategy, the Cloud TPU strategy and the one-device GPU strategy.
  """
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies)
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end smoke test for the sample Keras MNIST model.

  The single test method is decorated with `combinations.generate`, using the
  parameter sets returned by `eager_strategy_combinations()`.
  """
  # Not read or written anywhere inside this class.
  _tempdir = None

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    """Delegates to the parent class-level setup without adding anything."""
    super(KerasMnistTest, cls).setUpClass()

  def tearDown(self):
    """Runs the parent teardown, then removes the `get_temp_dir()` tree."""
    super(KerasMnistTest, self).tearDown()
    tf.io.gfile.rmtree(self.get_temp_dir())

  @combinations.generate(eager_strategy_combinations())
  def test_end_to_end(self, distribution):
    """Runs `mnist_main.run` once on dummy data under `distribution`.

    Args:
      distribution: The strategy object supplied by the `combinations`
        decorator; forwarded to `mnist_main.run` as `strategy_override`.
    """
    cli_flags = [
        "-train_epochs",
        "1",
        # Let TFDS find the metadata folder automatically
        "--data_dir="
    ]

    # Ten all-ones 28x28x1 int32 "images" paired with the labels 0..9.
    images = tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32)
    labels = tf.range(10)
    fake_examples = (images, labels)

    # Two separate datasets are built from the same tensors and handed to
    # `mnist_main.run` as a pair via `datasets_override`.
    dataset_pair = tuple(
        tf.data.Dataset.from_tensor_slices(fake_examples) for _ in range(2))

    run_with_overrides = functools.partial(
        mnist_main.run,
        datasets_override=dataset_pair,
        strategy_override=distribution)

    scratch_root = self.create_tempdir().full_path
    integration.run_synthetic(
        main=run_with_overrides,
        synth=False,
        tmp_root=scratch_root,
        extra_flags=cli_flags)
86
87
88
89


# When executed as a script, hand control to the TensorFlow test runner.
if __name__ == "__main__":
  tf.test.main()