Unverified Commit 092def7b authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Fix resnet tests (#7071)

parent 1636acc9
...@@ -24,6 +24,7 @@ import tensorflow as tf ...@@ -24,6 +24,7 @@ import tensorflow as tf
from official.resnet import cifar10_main from official.resnet import cifar10_main
from official.resnet.keras import keras_cifar_main from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
# pylint: disable=ungrouped-imports # pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
...@@ -61,7 +62,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -61,7 +62,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_no_dist_strat(self): def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy.""" """Test Keras model with 1 GPU, no distribution strategy."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [ extra_flags = [
...@@ -95,7 +96,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -95,7 +96,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_1_gpu(self): def test_end_to_end_1_gpu(self):
"""Test Keras model with 1 GPU.""" """Test Keras model with 1 GPU."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 1: if context.num_gpus() < 1:
...@@ -141,7 +142,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -141,7 +142,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_2_gpu(self): def test_end_to_end_2_gpu(self):
"""Test Keras model with 2 GPUs.""" """Test Keras model with 2 GPUs."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2: if context.num_gpus() < 2:
......
...@@ -24,6 +24,7 @@ import tensorflow as tf ...@@ -24,6 +24,7 @@ import tensorflow as tf
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
# pylint: disable=ungrouped-imports # pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
...@@ -61,7 +62,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -61,7 +62,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_no_dist_strat(self): def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy.""" """Test Keras model with 1 GPU, no distribution strategy."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [ extra_flags = [
...@@ -95,7 +96,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -95,7 +96,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_1_gpu(self): def test_end_to_end_1_gpu(self):
"""Test Keras model with 1 GPU.""" """Test Keras model with 1 GPU."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 1: if context.num_gpus() < 1:
...@@ -141,7 +142,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -141,7 +142,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_2_gpu(self): def test_end_to_end_2_gpu(self):
"""Test Keras model with 2 GPUs.""" """Test Keras model with 2 GPUs."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2: if context.num_gpus() < 2:
...@@ -164,7 +165,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -164,7 +165,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_xla_2_gpu(self): def test_end_to_end_xla_2_gpu(self):
"""Test Keras model with XLA and 2 GPUs.""" """Test Keras model with XLA and 2 GPUs."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2: if context.num_gpus() < 2:
...@@ -188,7 +189,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -188,7 +189,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_2_gpu_fp16(self): def test_end_to_end_2_gpu_fp16(self):
"""Test Keras model with 2 GPUs and fp16.""" """Test Keras model with 2 GPUs and fp16."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2: if context.num_gpus() < 2:
...@@ -212,7 +213,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -212,7 +213,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_xla_2_gpu_fp16(self): def test_end_to_end_xla_2_gpu_fp16(self):
"""Test Keras model with XLA, 2 GPUs and fp16.""" """Test Keras model with XLA, 2 GPUs and fp16."""
config = keras_common.get_config_proto_v1() config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config) tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2: if context.num_gpus() < 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment