Unverified Commit 8367cf6d authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

fix resnet breakage and add keras end-to-end tests (#6295)

* fix resnet breakage and add keras end-to-end tests

* delint

* address PR comments
parent 048e5bff
...@@ -30,6 +30,8 @@ from official.recommendation import data_pipeline ...@@ -30,6 +30,8 @@ from official.recommendation import data_pipeline
from official.recommendation import neumf_model from official.recommendation import neumf_model
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main
from official.utils.testing import integration
NUM_TRAIN_NEG = 4 NUM_TRAIN_NEG = 4
...@@ -183,22 +185,35 @@ class NcfTest(tf.test.TestCase): ...@@ -183,22 +185,35 @@ class NcfTest(tf.test.TestCase):
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) + self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4) 2 * math.log(2) / math.log(4)) / 4)
_BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1']
_BASE_END_TO_END_FLAGS = { @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
"batch_size": 1024, def test_end_to_end_estimator(self):
"train_epochs": 1, integration.run_synthetic(
"use_synthetic_data": True ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
} extra_flags=self._BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@flagsaver.flagsaver(**_BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end(self): def test_end_to_end_keras(self):
ncf_estimator_main.main(None) self.skipTest("TODO: fix synthetic data with keras")
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@flagsaver.flagsaver(ml_perf=True, **_BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_mlperf(self): def test_end_to_end_keras_mlperf(self):
ncf_estimator_main.main(None) self.skipTest("TODO: fix synthetic data with keras")
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS +
['-ml_perf', 'True', '-distribution_strategy', 'off'])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -23,6 +23,8 @@ import numpy as np ...@@ -23,6 +23,8 @@ import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main from official.resnet import cifar10_main
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.testing import integration from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
...@@ -37,14 +39,23 @@ class BaseTest(tf.test.TestCase): ...@@ -37,14 +39,23 @@ class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet. """Tests for the Cifar10 version of Resnet.
""" """
_num_validation_images = None
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass() super(BaseTest, cls).setUpClass()
cifar10_main.define_cifar_flags() cifar10_main.define_cifar_flags()
keras_common.define_keras_flags()
def setUp(self):
super(BaseTest, self).setUp()
self._num_validation_images = cifar10_main.NUM_IMAGES['validation']
cifar10_main.NUM_IMAGES['validation'] = 4
def tearDown(self): def tearDown(self):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
cifar10_main.NUM_IMAGES['validation'] = self._num_validation_images
def test_dataset_input_fn(self): def test_dataset_input_fn(self):
fake_data = bytearray() fake_data = bytearray()
...@@ -157,13 +168,20 @@ class BaseTest(tf.test.TestCase): ...@@ -157,13 +168,20 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1'] extra_flags=['-resnet_version', '1', '-batch_size', '4']
) )
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2'] extra_flags=['-resnet_version', '2', '-batch_size', '4']
)
def test_cifar10_end_to_end_keras_synthetic_v1(self):
integration.run_synthetic(
main=keras_cifar_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-train_steps', '1']
) )
......
...@@ -22,6 +22,8 @@ import unittest ...@@ -22,6 +22,8 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.utils.testing import integration from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
...@@ -32,14 +34,23 @@ _LABEL_CLASSES = 1001 ...@@ -32,14 +34,23 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
_num_validation_images = None
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass() super(BaseTest, cls).setUpClass()
imagenet_main.define_imagenet_flags() imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
def setUp(self):
super(BaseTest, self).setUp()
self._num_validation_images = imagenet_main.NUM_IMAGES['validation']
imagenet_main.NUM_IMAGES['validation'] = 4
def tearDown(self): def tearDown(self):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
imagenet_main.NUM_IMAGES['validation'] = self._num_validation_images
def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu): def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model.""" """Checks the tensor shapes after each phase of the ResNet model."""
...@@ -271,37 +282,48 @@ class BaseTest(tf.test.TestCase): ...@@ -271,37 +282,48 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-resnet_version', '1', '-batch_size', '4']
) )
def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-resnet_version', '2', '-batch_size', '4']
) )
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '18'] extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '18'] extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '200'] extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '200']
) )
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '200'] extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '200']
)
def test_imagenet_end_to_end_keras_synthetic_v1(self):
integration.run_synthetic(
main=keras_imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-train_steps', '1']
) )
......
...@@ -150,7 +150,7 @@ def run(flags_obj): ...@@ -150,7 +150,7 @@ def run(flags_obj):
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus)
strategy_scope = distribution_utils.MaybeDistributionScope(strategy) strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
......
...@@ -144,7 +144,7 @@ def run(flags_obj): ...@@ -144,7 +144,7 @@ def run(flags_obj):
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus)
strategy_scope = distribution_utils.MaybeDistributionScope(strategy) strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
......
...@@ -92,10 +92,6 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -92,10 +92,6 @@ def get_distribution_strategy(distribution_strategy="default",
if distribution_strategy == "parameter_server": if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy() return tf.distribute.experimental.ParameterServerStrategy()
if distribution_strategy == "collective":
return tf.contrib.distribute.CollectiveAllReduceStrategy(
num_gpus_per_worker=num_gpus)
raise ValueError( raise ValueError(
"Unrecognized Distribution Strategy: %r" % distribution_strategy) "Unrecognized Distribution Strategy: %r" % distribution_strategy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment