# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A convenience wrapper around tf.test.TestCase to test with TPU, TF1, TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import zip
21
import tensorflow as tf
22
23
24
from tensorflow.python import tf2  # pylint: disable=import-outside-toplevel
if not tf2.enabled():
  from tensorflow.contrib import tpu as contrib_tpu  # pylint: disable=g-import-not-at-top, line-too-long
25
26
27

# `tf.app` only exists in the TF 1.X namespace; `tf.compat.v1.app.flags` is
# available under both TF 1.X and TF 2.X, so this module can be imported
# regardless of which version is installed.
flags = tf.compat.v1.app.flags

# Kept only so that existing command lines passing --tpu_test still parse.
flags.DEFINE_bool('tpu_test', False, 'Deprecated Flag.')

FLAGS = flags.FLAGS

class TestCase(tf.test.TestCase):
  """Base Test class to handle execution under {TF1.X, TF2.X} x {TPU, CPU}.

  This class determines the TF version and availability of TPUs to set up
  tests appropriately.
  """

  def maybe_extract_single_output(self, outputs):
    """Converts eager tensors to numpy and unwraps single-element sequences.

    Args:
      outputs: a single value, a list/tuple of values, or a dict of values.
        Values that are `tf.Tensor`s (eager results from the TF2 code paths)
        are converted with `.numpy()`; any other value (e.g. the numpy arrays
        produced by `Session.run` in the TF1 code paths) is passed through.

    Returns:
      For a list/tuple with exactly one element, that element (converted).
      For any other list/tuple (including an empty one), a list of the
      converted elements. For a dict, a dict with converted values.
      Otherwise, the converted value itself.
    """
    def to_numpy(value):
      return value.numpy() if isinstance(value, tf.Tensor) else value

    if isinstance(outputs, (list, tuple)):
      # Convert element-wise rather than deciding from outputs[0]: this
      # tolerates empty sequences and mixed tensor / non-tensor contents.
      outputs_np = [to_numpy(output) for output in outputs]
      if len(outputs_np) == 1:
        return outputs_np[0]
      return outputs_np
    if isinstance(outputs, dict):
      # TF1's Session.run already returns numpy values for dict fetches; do
      # the same for TF2 so callers see consistent types across versions.
      return {key: to_numpy(value) for key, value in outputs.items()}
    return to_numpy(outputs)

  def has_tpu(self):
    """Returns whether there are any logical TPU devices."""
    return bool(tf.config.experimental.list_logical_devices(device_type='TPU'))

  def is_tf2(self):
    """Returns whether TF2 is enabled."""
    return tf2.enabled()

  @staticmethod
  def _listify_results(results):
    """Normalizes `compute_fn` results before they are fetched.

    Dicts and single `tf.Tensor`s are returned unchanged; any other iterable
    (e.g. a tuple or generator) is materialized into a list.
    """
    if (not isinstance(results, (dict, tf.Tensor))
        and hasattr(results, '__iter__')):
      return list(results)
    return results

  def execute_tpu_tf1(self, compute_fn, inputs):
    """Executes compute_fn on TPU with Tensorflow 1.X.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled.
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]

      def wrap_graph_fn(*args, **kwargs):
        return self._listify_results(compute_fn(*args, **kwargs))

      tpu_computation = contrib_tpu.rewrite(wrap_graph_fn, placeholders)
      sess.run(contrib_tpu.initialize_system())
      try:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                  tf.local_variables_initializer()])
        materialized_results = sess.run(
            tpu_computation, feed_dict=dict(zip(placeholders, inputs)))
      finally:
        # Shut the TPU system down even if the computation raised, so it is
        # not left initialized behind a failing test.
        sess.run(contrib_tpu.shutdown_system())
    return self.maybe_extract_single_output(materialized_results)

  def execute_tpu_tf2(self, compute_fn, inputs):
    """Executes compute_fn on TPU with Tensorflow 2.X.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        (via `strategy.run`) with one tensor argument per element of `inputs`
        and returns the tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    try:
      device_assignment = tf.tpu.experimental.DeviceAssignment.build(
          topology, num_replicas=1)
      strategy = tf.distribute.experimental.TPUStrategy(
          resolver, device_assignment=device_assignment)

      @tf.function
      def run():
        tf_inputs = [tf.constant(input_t) for input_t in inputs]
        return strategy.run(compute_fn, args=tf_inputs)

      outputs = run()
    finally:
      # Shut the TPU system down even if the computation raised.
      tf.tpu.experimental.shutdown_tpu_system()
    return self.maybe_extract_single_output(outputs)

  def execute_cpu_tf1(self, compute_fn, inputs):
    """Executes compute_fn on CPU with Tensorflow 1.X.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one placeholder tensor per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled.
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      results = self._listify_results(compute_fn(*placeholders))
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(
          results, feed_dict=dict(zip(placeholders, inputs)))
    return self.maybe_extract_single_output(materialized_results)

  def execute_cpu_tf2(self, compute_fn, inputs):
    """Executes compute_fn on CPU with Tensorflow 2.X.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        inside a `tf.function` with one tensor argument per element of
        `inputs` and returns the tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')

    @tf.function
    def run():
      tf_inputs = [tf.constant(input_t) for input_t in inputs]
      return compute_fn(*tf_inputs)

    return self.maybe_extract_single_output(run())

  def execute_cpu(self, compute_fn, inputs):
    """Executes compute_fn on CPU.

    Depending on the underlying TensorFlow installation (build deps) runs in
    either TF 1.X or TF 2.X style.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.
    """
    if self.is_tf2():
      return self.execute_cpu_tf2(compute_fn, inputs)
    return self.execute_cpu_tf1(compute_fn, inputs)

  def execute_tpu(self, compute_fn, inputs):
    """Executes compute_fn on TPU.

    Depending on the underlying TensorFlow installation (build deps) runs in
    either TF 1.X or TF 2.X style.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if no logical TPU device is available.
    """
    if not self.has_tpu():
      raise ValueError('No TPU Device found.')
    if self.is_tf2():
      return self.execute_tpu_tf2(compute_fn, inputs)
    return self.execute_tpu_tf1(compute_fn, inputs)

  def execute_tf2(self, compute_fn, inputs):
    """Runs compute_fn with TensorFlow 2.0.

    Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')
    if self.has_tpu():
      return self.execute_tpu_tf2(compute_fn, inputs)
    return self.execute_cpu_tf2(compute_fn, inputs)

  def execute_tf1(self, compute_fn, inputs):
    """Runs compute_fn with TensorFlow 1.X.

    Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled.
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    if self.has_tpu():
      return self.execute_tpu_tf1(compute_fn, inputs)
    return self.execute_cpu_tf1(compute_fn, inputs)

  def execute(self, compute_fn, inputs):
    """Runs compute_fn with inputs and returns results.

    * Executes in either TF1.X or TF2.X style based on the TensorFlow version.
    * Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing Tensorflow computation. It is called
        with one tensor argument per element of `inputs` and returns the
        tensor(s) to evaluate.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.
    """
    # execute_tf2 / execute_tf1 already pick TPU vs. CPU via has_tpu().
    if self.is_tf2():
      return self.execute_tf2(compute_fn, inputs)
    return self.execute_tf1(compute_fn, inputs)