test_case.py 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""A convenience wrapper around tf.test.TestCase to test with TPU, TF1, TF2."""
16

pkulzc's avatar
pkulzc committed
17
18
19
20
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import zip
21
import tensorflow.compat.v1 as tf
22
from tensorflow.python import tf2  # pylint: disable=import-outside-toplevel
23
from object_detection.utils import tf_version
24
25
# tf.contrib does not exist in TF2, so the contrib TPU module is imported only
# under TF1; it is used solely by the TF1 TPU execution path (execute_tpu_tf1).
if not tf2.enabled():
  from tensorflow.contrib import tpu as contrib_tpu  # pylint: disable=g-import-not-at-top, line-too-long

flags = tf.app.flags
FLAGS = flags.FLAGS

# Marked deprecated and not read anywhere in this module; presumably kept so
# command lines that still pass --tpu_test continue to parse.
flags.DEFINE_bool('tpu_test', False, 'Deprecated Flag.')


33
34
class TestCase(tf.test.TestCase):
  """Base Test class to handle execution under {TF1.X, TF2.X} x {TPU, CPU}.

  This class determines the TF version and availability of TPUs to set up
  tests appropriately. `execute` picks the matching
  `execute_{cpu,tpu}_{tf1,tf2}` method automatically.
  """

  def maybe_extract_single_output(self, outputs):
    """Converts computation outputs to numpy and unwraps single outputs.

    Args:
      outputs: the value produced by a computation: a single value, or a list
        or tuple of values. Every `tf.Tensor` among them is converted with
        `.numpy()`; any other value is returned unchanged.

    Returns:
      For a list/tuple of length one, its (converted) sole element; for any
      other list/tuple (including an empty one), a list of converted values;
      otherwise the converted value itself.
    """
    def to_numpy(value):
      # `.numpy()` needs eager execution; TF1 `sess.run` results are already
      # numpy values and pass through untouched.
      return value.numpy() if isinstance(value, tf.Tensor) else value

    if not isinstance(outputs, (list, tuple)):
      return to_numpy(outputs)
    # Convert element-wise instead of inferring from outputs[0]: this works for
    # empty sequences and for mixed tensor / non-tensor results.
    outputs_np = [to_numpy(output) for output in outputs]
    if len(outputs_np) == 1:
      return outputs_np[0]
    return outputs_np

  def has_tpu(self):
    """Returns whether there are any logical TPU devices."""
    return bool(tf.config.experimental.list_logical_devices(device_type='TPU'))

  def is_tf2(self):
    """Returns whether TF2 is enabled."""
    return tf_version.is_tf2()

  @staticmethod
  def _listify_results(results):
    """Turns iterable computation results into a list before fetching them.

    Dicts and single tensors are returned unchanged, since `list()` would turn
    them into their keys / elements. Any other iterable (e.g. a tuple) becomes
    a list.

    Args:
      results: the value returned by a user `compute_fn`.

    Returns:
      `results`, converted to a list when it is a non-dict, non-Tensor
      iterable.
    """
    if (not isinstance(results, (dict, tf.Tensor))
        and hasattr(results, '__iter__')):
      return list(results)
    return results

  @staticmethod
  def _run_initializers(sess):
    """Runs the global-variable, table and local-variable initializers."""
    sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
              tf.local_variables_initializer()])

  def execute_tpu_tf1(self, compute_fn, inputs, graph=None):
    """Executes compute_fn on TPU with Tensorflow 1.X.

    Args:
      compute_fn: a function containing TensorFlow computation. It is called
        with one placeholder per element of `inputs` as positional arguments
        and returns a tensor, or a dict or iterable of tensors.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph().

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled (the contrib TPU module is only
        imported under TF1).
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    with self.session(graph=(graph or tf.Graph())) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]

      def wrap_graph_fn(*args, **kwargs):
        return self._listify_results(compute_fn(*args, **kwargs))

      tpu_computation = contrib_tpu.rewrite(wrap_graph_fn, placeholders)
      sess.run(contrib_tpu.initialize_system())
      try:
        self._run_initializers(sess)
        materialized_results = sess.run(
            tpu_computation, feed_dict=dict(zip(placeholders, inputs)))
      finally:
        # Shut the TPU system down even if the computation raises, so every
        # initialize_system() call is paired with a shutdown.
        sess.run(contrib_tpu.shutdown_system())
    return self.maybe_extract_single_output(materialized_results)

  def execute_tpu_tf2(self, compute_fn, inputs):
    """Executes compute_fn on TPU with Tensorflow 2.X.

    Args:
      compute_fn: a function containing TensorFlow computation. It is run
        through a single-replica TPUStrategy with one constant tensor per
        element of `inputs` as positional arguments and returns the output
        tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    try:
      device_assignment = tf.tpu.experimental.DeviceAssignment.build(
          topology, num_replicas=1)
      strategy = tf.distribute.experimental.TPUStrategy(
          resolver, device_assignment=device_assignment)

      @tf.function
      def run():
        tf_inputs = [tf.constant(input_t) for input_t in inputs]
        return strategy.run(compute_fn, args=tf_inputs)

      outputs = run()
    finally:
      # Pair the initialize_tpu_system() above with a shutdown even when the
      # computation raises.
      tf.tpu.experimental.shutdown_tpu_system()
    return self.maybe_extract_single_output(outputs)

  def execute_cpu_tf1(self, compute_fn, inputs, graph=None):
    """Executes compute_fn on CPU with Tensorflow 1.X.

    Args:
      compute_fn: a function containing TensorFlow computation. It is called
        with one placeholder per element of `inputs` as positional arguments
        and returns a tensor, or a dict or iterable of tensors.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph().

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled.
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    with self.session(graph=(graph or tf.Graph())) as sess:
      # placeholder_with_default lets the graph be built with the real values
      # as defaults while they are still fed explicitly below.
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      results = self._listify_results(compute_fn(*placeholders))
      self._run_initializers(sess)
      materialized_results = sess.run(
          results, feed_dict=dict(zip(placeholders, inputs)))
    return self.maybe_extract_single_output(materialized_results)

  def execute_cpu_tf2(self, compute_fn, inputs):
    """Executes compute_fn on CPU with Tensorflow 2.X.

    Args:
      compute_fn: a function containing TensorFlow computation. It is traced
        inside a `tf.function` with one constant tensor per element of
        `inputs` as positional arguments and returns the output tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')

    @tf.function
    def run():
      tf_inputs = [tf.constant(input_t) for input_t in inputs]
      return compute_fn(*tf_inputs)

    return self.maybe_extract_single_output(run())

  def execute_cpu(self, compute_fn, inputs, graph=None):
    """Executes compute_fn on CPU.

    Depending on the underlying TensorFlow installation (build deps) runs in
    either TF 1.X or TF 2.X style.

    Args:
      compute_fn: a function containing TensorFlow computation. It receives
        one tensor per element of `inputs` as positional arguments and returns
        the output tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph(). Only used under TF 1.X.

    Returns:
      A list of numpy arrays or a single numpy array.
    """
    if self.is_tf2():
      return self.execute_cpu_tf2(compute_fn, inputs)
    return self.execute_cpu_tf1(compute_fn, inputs, graph)

  def execute_tpu(self, compute_fn, inputs, graph=None):
    """Executes compute_fn on TPU.

    Depending on the underlying TensorFlow installation (build deps) runs in
    either TF 1.X or TF 2.X style.

    Args:
      compute_fn: a function containing TensorFlow computation. It receives
        one tensor per element of `inputs` as positional arguments and returns
        the output tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph(). Only used under TF 1.X.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if no logical TPU device is available.
    """
    if not self.has_tpu():
      raise ValueError('No TPU Device found.')
    if self.is_tf2():
      return self.execute_tpu_tf2(compute_fn, inputs)
    return self.execute_tpu_tf1(compute_fn, inputs, graph)

  def execute_tf2(self, compute_fn, inputs):
    """Runs compute_fn with TensorFlow 2.0.

    Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing TensorFlow computation. It receives
        one constant tensor per element of `inputs` as positional arguments
        and returns the output tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is not enabled.
    """
    if not self.is_tf2():
      raise ValueError('Required version TensorFlow 2.0 is not available.')
    if self.has_tpu():
      return self.execute_tpu_tf2(compute_fn, inputs)
    return self.execute_cpu_tf2(compute_fn, inputs)

  def execute_tf1(self, compute_fn, inputs, graph=None):
    """Runs compute_fn with TensorFlow 1.X.

    Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing TensorFlow computation. It receives
        one placeholder per element of `inputs` as positional arguments and
        returns a tensor, or a dict or iterable of tensors.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph().

    Returns:
      A list of numpy arrays or a single numpy array.

    Raises:
      ValueError: if TensorFlow 2.X is enabled.
    """
    if self.is_tf2():
      raise ValueError('Required version TensorFlow 1.X is not available.')
    if self.has_tpu():
      return self.execute_tpu_tf1(compute_fn, inputs, graph)
    return self.execute_cpu_tf1(compute_fn, inputs, graph)

  def execute(self, compute_fn, inputs, graph=None):
    """Runs compute_fn with inputs and returns results.

    * Executes in either TF1.X or TF2.X style based on the TensorFlow version.
    * Executes on TPU if available, otherwise executes on CPU.

    Args:
      compute_fn: a function containing TensorFlow computation. It receives
        one tensor per element of `inputs` as positional arguments and returns
        the output tensor(s).
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph(). Only used under TF 1.X.

    Returns:
      A list of numpy arrays or a single numpy array.
    """
    # Query the device list once instead of once per branch condition.
    use_tpu = self.has_tpu()
    if self.is_tf2():
      if use_tpu:
        return self.execute_tpu_tf2(compute_fn, inputs)
      return self.execute_cpu_tf2(compute_fn, inputs)
    if use_tpu:
      return self.execute_tpu_tf1(compute_fn, inputs, graph)
    return self.execute_cpu_tf1(compute_fn, inputs, graph)