resnet_imagenet_test.py 7.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the keras ResNet model with ImageNet data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.eager import context
from official.benchmark.models import resnet_imagenet_main
from official.utils.testing import integration
from official.vision.image_classification.resnet import imagenet_preprocessing


@parameterized.parameters(
    "resnet",
    # "resnet_polynomial_decay",  b/151854314
    "mobilenet",
    # "mobilenet_polynomial_decay"  b/151854314
)
class KerasImagenetTest(tf.test.TestCase):
  """Unit tests for Keras Models with ImageNet.

  The class decorator parameterizes every `test_*` method with a `flags_key`
  naming an entry of `_extra_flags_dict`. Each test runs
  `resnet_imagenet_main.run` on synthetic data with the test-specific flags,
  the model flags for `flags_key`, and `_default_flags_dict`.
  """
  # Flag tokens shared by every test: a tiny synthetic run (batch 4, 1 step).
  # Despite the name this is a list of command-line tokens, not a dict.
  _default_flags_dict = [
      "-batch_size",
      "4",
      "-train_steps",
      "1",
      "-use_synthetic_data",
      "true",
      "-data_format",
      "channels_last",
  ]
  # Model-specific flag tokens, keyed by the parameterized `flags_key`.
  _extra_flags_dict = {
      "resnet": [
          "-model",
          "resnet50_v1.5",
          "-optimizer",
          "resnet50_default",
      ],
      "resnet_polynomial_decay": [
          "-model",
          "resnet50_v1.5",
          "-optimizer",
          "resnet50_default",
          "-pruning_method",
          "polynomial_decay",
      ],
      "mobilenet": [
          "-model",
          "mobilenet",
          "-optimizer",
          "mobilenet_default",
      ],
      "mobilenet_polynomial_decay": [
          "-model",
          "mobilenet",
          "-optimizer",
          "mobilenet_default",
          "-pruning_method",
          "polynomial_decay",
      ],
  }
  _tempdir = None

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(KerasImagenetTest, cls).setUpClass()
    # Register the command-line flags consumed by resnet_imagenet_main.run.
    resnet_imagenet_main.define_imagenet_keras_flags()

  def setUp(self):
    super(KerasImagenetTest, self).setUp()
    # Shrink the validation set so evaluation is quick. The previous value is
    # saved and restored in tearDown so this module-level table is not left
    # mutated for other tests in the same process.
    self._saved_num_validation_images = (
        imagenet_preprocessing.NUM_IMAGES["validation"])
    imagenet_preprocessing.NUM_IMAGES["validation"] = 4
    # Remember the global mixed-precision policy so tearDown can restore it
    # (presumably altered by the "-dtype fp16" runs -- confirm in run()).
    self.policy = tf.keras.mixed_precision.experimental.global_policy()

  def tearDown(self):
    try:
      super(KerasImagenetTest, self).tearDown()
      tf.io.gfile.rmtree(self.get_temp_dir())
    finally:
      # Always undo global state changes, even if the cleanup above raises,
      # so one test cannot leak state into the next.
      imagenet_preprocessing.NUM_IMAGES["validation"] = (
          self._saved_num_validation_images)
      tf.keras.mixed_precision.experimental.set_policy(self.policy)

  def get_extra_flags_dict(self, flags_key):
    """Returns the model flags for `flags_key` followed by the defaults.

    Args:
      flags_key: a key of `_extra_flags_dict`, e.g. "resnet".

    Returns:
      A new list of command-line tokens.
    """
    return self._extra_flags_dict[flags_key] + self._default_flags_dict

  def _skip_if_fewer_gpus(self, required):
    """Skips the current test unless at least `required` GPUs are visible."""
    available = context.num_gpus()
    if available < required:
      self.skipTest(
          "{} GPUs are not available for this test. {} GPUs are available"
          .format(required, available))

  def _run_end_to_end(self, flags_key, extra_flags):
    """Runs the model on synthetic data.

    Args:
      flags_key: a key of `_extra_flags_dict` selecting the model flags.
      extra_flags: test-specific flag tokens, placed before the model and
        default flags.

    Skips the test when fp16 is combined with a pruning configuration,
    which is not currently supported.
    """
    flags = list(extra_flags) + self.get_extra_flags_dict(flags_key)
    if "fp16" in flags and "polynomial_decay" in flags:
      self.skipTest("Pruning with fp16 is not currently supported.")
    integration.run_synthetic(
        main=resnet_imagenet_main.run,
        tmp_root=self.get_temp_dir(),
        extra_flags=flags)

  def test_end_to_end_no_dist_strat(self, flags_key):
    """Test Keras model with 1 GPU, no distribution strategy."""
    self._run_end_to_end(flags_key, [
        "-distribution_strategy",
        "off",
    ])

  def test_end_to_end_graph_no_dist_strat(self, flags_key):
    """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
    self._run_end_to_end(flags_key, [
        "-enable_eager",
        "false",
        "-distribution_strategy",
        "off",
    ])

  def test_end_to_end_1_gpu(self, flags_key):
    """Test Keras model with 1 GPU."""
    self._skip_if_fewer_gpus(1)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "1",
        "-distribution_strategy",
        "mirrored",
        "-enable_checkpoint_and_export",
        "1",
    ])

  def test_end_to_end_1_gpu_fp16(self, flags_key):
    """Test Keras model with 1 GPU and fp16."""
    self._skip_if_fewer_gpus(1)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "1",
        "-dtype",
        "fp16",
        "-distribution_strategy",
        "mirrored",
    ])

  def test_end_to_end_2_gpu(self, flags_key):
    """Test Keras model with 2 GPUs."""
    self._skip_if_fewer_gpus(2)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "2",
        "-distribution_strategy",
        "mirrored",
    ])

  def test_end_to_end_xla_2_gpu(self, flags_key):
    """Test Keras model with XLA and 2 GPUs."""
    self._skip_if_fewer_gpus(2)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "2",
        "-enable_xla",
        "true",
        "-distribution_strategy",
        "mirrored",
    ])

  def test_end_to_end_2_gpu_fp16(self, flags_key):
    """Test Keras model with 2 GPUs and fp16."""
    self._skip_if_fewer_gpus(2)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "2",
        "-dtype",
        "fp16",
        "-distribution_strategy",
        "mirrored",
    ])

  def test_end_to_end_xla_2_gpu_fp16(self, flags_key):
    """Test Keras model with XLA, 2 GPUs and fp16."""
    self._skip_if_fewer_gpus(2)
    self._run_end_to_end(flags_key, [
        "-num_gpus",
        "2",
        "-dtype",
        "fp16",
        "-enable_xla",
        "true",
        "-distribution_strategy",
        "mirrored",
    ])


if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()