mnist_eager_test.py 2.84 KB
Newer Older
Asim Shankar's avatar
Asim Shankar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
21
import unittest

Karmel Allison's avatar
Karmel Allison committed
22
import tensorflow as tf  # pylint: disable=g-bad-import-order
23
from tensorflow.python import eager as tfe  # pylint: disable=g-bad-import-order
Asim Shankar's avatar
Asim Shankar committed
24

25
26
from official.mnist import mnist
from official.mnist import mnist_eager
Toby Boyd's avatar
Toby Boyd committed
27
from official.utils.misc import keras_utils
Asim Shankar's avatar
Asim Shankar committed
28
29
30


def device():
  """Pick the device string the tests should run on.

  Returns:
    '/device:GPU:0' when TensorFlow reports at least one GPU, otherwise
    '/device:CPU:0'.
  """
  if tfe.context.num_gpus():
    return '/device:GPU:0'
  return '/device:CPU:0'
Asim Shankar's avatar
Asim Shankar committed
32
33
34


def data_format():
  """Pick the image data layout passed to `mnist.create_model`.

  Uses the same GPU check as `device()`.

  Returns:
    'channels_first' when TensorFlow reports at least one GPU, otherwise
    'channels_last'.
  """
  if not tfe.context.num_gpus():
    return 'channels_last'
  return 'channels_first'
Asim Shankar's avatar
Asim Shankar committed
36
37
38
39
40
41
42
43
44
45


def random_dataset():
  """Build a dataset containing a single batch of random inputs and labels.

  Returns:
    A `tf.data.Dataset` with exactly one element, an `(images, labels)` pair.
    `images` is a [64, 784] tensor sampled by `tf.random_normal` with its
    default arguments. `labels` is a [64] int32 tensor sampled uniformly from
    [0, 10).
  """
  num_examples = 64
  # Sample the images before the labels, matching the original op order so
  # seeded runs draw the same values.
  fake_images = tf.random_normal([num_examples, 784])
  fake_labels = tf.random_uniform(
      [num_examples], minval=0, maxval=10, dtype=tf.int32)
  single_batch = (fake_images, fake_labels)
  return tf.data.Dataset.from_tensors(single_batch)


def train(defun=False):
  """Run `mnist_eager.train` on a freshly created model and random data.

  Args:
    defun: when True, the model's `call` method is replaced by a
      `tf.function`-wrapped version before training.
  """
  net = mnist.create_model(data_format())
  if defun:
    net.call = tf.function(net.call)
  sgd = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  train_data = random_dataset()
  with tf.device(device()):
    # The global step is fetched/created inside the device scope, as before.
    global_step = tf.train.get_or_create_global_step()
    mnist_eager.train(net, sgd, train_data, step_counter=global_step)
Asim Shankar's avatar
Asim Shankar committed
54
55
56


def evaluate(defun=False):
  """Run `mnist_eager.test` on a freshly created model and random data.

  Args:
    defun: when True, the model's `call` method is replaced by a
      `tf.function`-wrapped version before evaluation.
  """
  net = mnist.create_model(data_format())
  eval_data = random_dataset()
  if defun:
    net.call = tf.function(net.call)
  with tf.device(device()):
    mnist_eager.test(net, eval_data)


class MNISTTest(tf.test.TestCase):
  """Exercise the MNIST eager training and evaluation loops.

  Per the original note, MNIST eager relies on contrib and does not work with
  TF 2.0, so every test here is skipped when `keras_utils.is_v2_0()` is true.
  """

  # One shared skip decorator for all tests. It is evaluated at class
  # definition time, exactly like the per-method decorators it replaces.
  _tf1_only = unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')

  def setUp(self):
    # On TF 1.x, opt into TF 2.x behavior before each test runs.
    if not keras_utils.is_v2_0():
      tf.compat.v1.enable_v2_behavior()
    super(MNISTTest, self).setUp()

  @_tf1_only
  def test_train(self):
    train(defun=False)

  @_tf1_only
  def test_evaluate(self):
    evaluate(defun=False)

  @_tf1_only
  def test_train_with_defun(self):
    train(defun=True)

  @_tf1_only
  def test_evaluate_with_defun(self):
    evaluate(defun=True)

  # The decorator is only needed while building the class body; remove it so
  # it does not linger as a class attribute.
  del _tf1_only


94
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner, which runs this module's tests.
  tf.test.main()