Unverified Commit 092def7b authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Fix resnet tests (#7071)

parent 1636acc9
......@@ -24,6 +24,7 @@ import tensorflow as tf
from official.resnet import cifar10_main
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
......@@ -61,7 +62,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
......@@ -95,7 +96,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_1_gpu(self):
"""Test Keras model with 1 GPU."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 1:
......@@ -141,7 +142,7 @@ class KerasCifarTest(googletest.TestCase):
def test_end_to_end_2_gpu(self):
"""Test Keras model with 2 GPUs."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
......
......@@ -24,6 +24,7 @@ import tensorflow as tf
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
......@@ -61,7 +62,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
......@@ -95,7 +96,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_1_gpu(self):
"""Test Keras model with 1 GPU."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 1:
......@@ -141,7 +142,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_2_gpu(self):
"""Test Keras model with 2 GPUs."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
......@@ -164,7 +165,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_xla_2_gpu(self):
"""Test Keras model with XLA and 2 GPUs."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
......@@ -188,7 +189,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_2_gpu_fp16(self):
"""Test Keras model with 2 GPUs and fp16."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
......@@ -212,7 +213,7 @@ class KerasImagenetTest(googletest.TestCase):
def test_end_to_end_xla_2_gpu_fp16(self):
"""Test Keras model with XLA, 2 GPUs and fp16."""
config = keras_common.get_config_proto_v1()
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment