Unverified commit 8367cf6d, authored by Taylor Robie and committed by GitHub
Browse files

fix resnet breakage and add keras end-to-end tests (#6295)

* fix resnet breakage and add keras end-to-end tests

* delint

* address PR comments
parent 048e5bff
......@@ -30,6 +30,8 @@ from official.recommendation import data_pipeline
from official.recommendation import neumf_model
from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main
from official.utils.testing import integration
NUM_TRAIN_NEG = 4
......@@ -183,22 +185,35 @@ class NcfTest(tf.test.TestCase):
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
_BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1']
_BASE_END_TO_END_FLAGS = {
"batch_size": 1024,
"train_epochs": 1,
"use_synthetic_data": True
}
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@flagsaver.flagsaver(**_BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end(self):
ncf_estimator_main.main(None)
def test_end_to_end_keras(self):
self.skipTest("TODO: fix synthetic data with keras")
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@flagsaver.flagsaver(ml_perf=True, **_BASE_END_TO_END_FLAGS)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_mlperf(self):
ncf_estimator_main.main(None)
def test_end_to_end_keras_mlperf(self):
self.skipTest("TODO: fix synthetic data with keras")
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS +
['-ml_perf', 'True', '-distribution_strategy', 'off'])
if __name__ == "__main__":
......
......@@ -23,6 +23,8 @@ import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
......@@ -37,14 +39,23 @@ class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet.
"""
_num_validation_images = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
cifar10_main.define_cifar_flags()
keras_common.define_keras_flags()
def setUp(self):
super(BaseTest, self).setUp()
self._num_validation_images = cifar10_main.NUM_IMAGES['validation']
cifar10_main.NUM_IMAGES['validation'] = 4
def tearDown(self):
super(BaseTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
cifar10_main.NUM_IMAGES['validation'] = self._num_validation_images
def test_dataset_input_fn(self):
fake_data = bytearray()
......@@ -157,13 +168,20 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1']
extra_flags=['-resnet_version', '1', '-batch_size', '4']
)
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2']
extra_flags=['-resnet_version', '2', '-batch_size', '4']
)
def test_cifar10_end_to_end_keras_synthetic_v1(self):
integration.run_synthetic(
main=keras_cifar_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-train_steps', '1']
)
......
......@@ -22,6 +22,8 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
......@@ -32,14 +34,23 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase):
_num_validation_images = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
def setUp(self):
super(BaseTest, self).setUp()
self._num_validation_images = imagenet_main.NUM_IMAGES['validation']
imagenet_main.NUM_IMAGES['validation'] = 4
def tearDown(self):
super(BaseTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
imagenet_main.NUM_IMAGES['validation'] = self._num_validation_images
def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model."""
......@@ -271,37 +282,48 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
extra_flags=['-resnet_version', '1', '-batch_size', '4']
)
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
extra_flags=['-resnet_version', '2', '-batch_size', '4']
)
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '18']
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '18']
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '200']
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '200']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '200']
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '200']
)
def test_imagenet_end_to_end_keras_synthetic_v1(self):
integration.run_synthetic(
main=keras_imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-train_steps', '1']
)
......
......@@ -150,7 +150,7 @@ def run(flags_obj):
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
strategy_scope = distribution_utils.MaybeDistributionScope(strategy)
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
......
......@@ -144,7 +144,7 @@ def run(flags_obj):
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
strategy_scope = distribution_utils.MaybeDistributionScope(strategy)
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
......
......@@ -92,10 +92,6 @@ def get_distribution_strategy(distribution_strategy="default",
if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy()
if distribution_strategy == "collective":
return tf.contrib.distribute.CollectiveAllReduceStrategy(
num_gpus_per_worker=num_gpus)
raise ValueError(
"Unrecognized Distribution Strategy: %r" % distribution_strategy)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment