Unverified Commit 49097655 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Unit tests pass TF 2.0 GPU and CPU locally. (#7101)

* Fix unit tests failures.

* 96% of TF 2.0 tests on GPU are passing.

* Currently all passing GPU and CPU TF 2.0

* Address code comments.

* use tf 2.0 cast.

* Comment about working on TF 2.0 CPU

* Uses contrib turn off for TF 2.0.

* Fix wide_deep and add keras_common_tests.

* use context to get num_gpus.

* Switch to tf.keras.metrics
parent 5175b7e6
...@@ -98,7 +98,7 @@ class BigQueryUploader(object): ...@@ -98,7 +98,7 @@ class BigQueryUploader(object):
this is a UUID4 format. this is a UUID4 format.
run_json_file: string, the file path that contains the run JSON data. run_json_file: string, the file path that contains the run JSON data.
""" """
with tf.gfile.GFile(run_json_file) as f: with tf.io.gfile.GFile(run_json_file) as f:
benchmark_json = json.load(f) benchmark_json = json.load(f)
self.upload_benchmark_run_json( self.upload_benchmark_run_json(
dataset_name, table_name, run_id, benchmark_json) dataset_name, table_name, run_id, benchmark_json)
...@@ -118,7 +118,7 @@ class BigQueryUploader(object): ...@@ -118,7 +118,7 @@ class BigQueryUploader(object):
metric_json_file: string, the file path that contains the metric JSON metric_json_file: string, the file path that contains the metric JSON
data. data.
""" """
with tf.gfile.GFile(metric_json_file) as f: with tf.io.gfile.GFile(metric_json_file) as f:
metrics = [] metrics = []
for line in f: for line in f:
metrics.append(json.loads(line.strip())) metrics.append(json.loads(line.strip()))
......
...@@ -61,7 +61,7 @@ class BigQueryUploaderTest(tf.test.TestCase): ...@@ -61,7 +61,7 @@ class BigQueryUploaderTest(tf.test.TestCase):
json.dump({"model_name": "value"}, f) json.dump({"model_name": "value"}, f)
def tearDown(self): def tearDown(self):
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
def test_upload_benchmark_run_json(self): def test_upload_benchmark_run_json(self):
self.benchmark_uploader.upload_benchmark_run_json( self.benchmark_uploader.upload_benchmark_run_json(
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import tempfile import tempfile
import unittest
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -26,11 +27,12 @@ import tensorflow as tf ...@@ -26,11 +27,12 @@ import tensorflow as tf
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from official.boosted_trees import train_higgs from official.boosted_trees import train_higgs
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
TEST_CSV = os.path.join(os.path.dirname(__file__), "train_higgs_test.csv") TEST_CSV = os.path.join(os.path.dirname(__file__), "train_higgs_test.csv")
tf.logging.set_verbosity(tf.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
...@@ -51,8 +53,9 @@ class BaseTest(tf.test.TestCase): ...@@ -51,8 +53,9 @@ class BaseTest(tf.test.TestCase):
# numpy.savez doesn't take gfile.Gfile, so need to write down and copy. # numpy.savez doesn't take gfile.Gfile, so need to write down and copy.
tmpfile = tempfile.NamedTemporaryFile() tmpfile = tempfile.NamedTemporaryFile()
np.savez_compressed(tmpfile, data=data) np.savez_compressed(tmpfile, data=data)
tf.gfile.Copy(tmpfile.name, self.input_npz) tf.io.gfile.copy(tmpfile.name, self.input_npz)
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_read_higgs_data(self): def test_read_higgs_data(self):
"""Tests read_higgs_data() function.""" """Tests read_higgs_data() function."""
# Error when a wrong data_dir is given. # Error when a wrong data_dir is given.
...@@ -68,6 +71,7 @@ class BaseTest(tf.test.TestCase): ...@@ -68,6 +71,7 @@ class BaseTest(tf.test.TestCase):
self.assertEqual((15, 29), train_data.shape) self.assertEqual((15, 29), train_data.shape)
self.assertEqual((5, 29), eval_data.shape) self.assertEqual((5, 29), eval_data.shape)
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_make_inputs_from_np_arrays(self): def test_make_inputs_from_np_arrays(self):
"""Tests make_inputs_from_np_arrays() function.""" """Tests make_inputs_from_np_arrays() function."""
train_data, _ = train_higgs.read_higgs_data( train_data, _ = train_higgs.read_higgs_data(
...@@ -115,6 +119,7 @@ class BaseTest(tf.test.TestCase): ...@@ -115,6 +119,7 @@ class BaseTest(tf.test.TestCase):
1.409523, -0.307865, 1.474605], 1.409523, -0.307865, 1.474605],
np.squeeze(features[feature_names[10]], 1)) np.squeeze(features[feature_names[10]], 1))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_end_to_end(self): def test_end_to_end(self):
"""Tests end-to-end running.""" """Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), "model") model_dir = os.path.join(self.get_temp_dir(), "model")
...@@ -131,6 +136,7 @@ class BaseTest(tf.test.TestCase): ...@@ -131,6 +136,7 @@ class BaseTest(tf.test.TestCase):
synth=False, max_train=None) synth=False, max_train=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint"))) self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_end_to_end_with_export(self): def test_end_to_end_with_export(self):
"""Tests end-to-end running.""" """Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), "model") model_dir = os.path.join(self.get_temp_dir(), "model")
......
...@@ -33,6 +33,7 @@ import time ...@@ -33,6 +33,7 @@ import time
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from tensorflow.python import eager as tfe
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.mnist import dataset as mnist_dataset from official.mnist import dataset as mnist_dataset
...@@ -41,8 +42,6 @@ from official.utils.flags import core as flags_core ...@@ -41,8 +42,6 @@ from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
tfe = tf.contrib.eager
def loss(logits, labels): def loss(logits, labels):
return tf.reduce_mean( return tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits( tf.nn.sparse_softmax_cross_entropy_with_logits(
...@@ -83,13 +82,13 @@ def train(model, optimizer, dataset, step_counter, log_interval=None): ...@@ -83,13 +82,13 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
def test(model, dataset): def test(model, dataset):
"""Perform an evaluation of `model` on the examples from `dataset`.""" """Perform an evaluation of `model` on the examples from `dataset`."""
avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32) avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32) accuracy = tf.keras.metrics.Accuracy('accuracy', dtype=tf.float32)
for (images, labels) in dataset: for (images, labels) in dataset:
logits = model(images, training=False) logits = model(images, training=False)
avg_loss(loss(logits, labels)) avg_loss.update_state(loss(logits, labels))
accuracy( accuracy.update_state(
tf.argmax(logits, axis=1, output_type=tf.int64), tf.argmax(logits, axis=1, output_type=tf.int64),
tf.cast(labels, tf.int64)) tf.cast(labels, tf.int64))
print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
......
...@@ -17,8 +17,10 @@ from __future__ import absolute_import ...@@ -17,8 +17,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order from tensorflow.python import eager as tfe # pylint: disable=g-bad-import-order
from official.mnist import mnist from official.mnist import mnist
from official.mnist import mnist_eager from official.mnist import mnist_eager
...@@ -26,11 +28,11 @@ from official.utils.misc import keras_utils ...@@ -26,11 +28,11 @@ from official.utils.misc import keras_utils
def device(): def device():
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" return '/device:GPU:0' if tfe.context.num_gpus() else '/device:CPU:0'
def data_format(): def data_format():
return "channels_first" if tfe.num_gpus() else "channels_last" return 'channels_first' if tfe.context.num_gpus() else 'channels_last'
def random_dataset(): def random_dataset():
...@@ -43,7 +45,7 @@ def random_dataset(): ...@@ -43,7 +45,7 @@ def random_dataset():
def train(defun=False): def train(defun=False):
model = mnist.create_model(data_format()) model = mnist.create_model(data_format())
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tf.function(model.call)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
dataset = random_dataset() dataset = random_dataset()
with tf.device(device()): with tf.device(device()):
...@@ -55,31 +57,39 @@ def evaluate(defun=False): ...@@ -55,31 +57,39 @@ def evaluate(defun=False):
model = mnist.create_model(data_format()) model = mnist.create_model(data_format())
dataset = random_dataset() dataset = random_dataset()
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tf.function(model.call)
with tf.device(device()): with tf.device(device()):
mnist_eager.test(model, dataset) mnist_eager.test(model, dataset)
class MNISTTest(tf.test.TestCase): class MNISTTest(tf.test.TestCase):
"""Run tests for MNIST eager loop.""" """Run tests for MNIST eager loop.
MNIST eager uses contrib and will not work with TF 2.0. All tests are
disabled if using TF 2.0.
"""
def setUp(self): def setUp(self):
if not keras_utils.is_v2_0(): if not keras_utils.is_v2_0():
tf.compat.v1.enable_v2_behavior() tf.compat.v1.enable_v2_behavior()
super(MNISTTest, self).setUp() super(MNISTTest, self).setUp()
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_train(self): def test_train(self):
train(defun=False) train(defun=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_evaluate(self): def test_evaluate(self):
evaluate(defun=False) evaluate(defun=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_train_with_defun(self): def test_train_with_defun(self):
train(defun=True) train(defun=True)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_evaluate_with_defun(self): def test_evaluate_with_defun(self):
evaluate(defun=True) evaluate(defun=True)
if __name__ == "__main__": if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -18,10 +18,12 @@ from __future__ import division ...@@ -18,10 +18,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import time import time
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import mnist from official.mnist import mnist
from official.utils.misc import keras_utils
BATCH_SIZE = 100 BATCH_SIZE = 100
...@@ -43,8 +45,13 @@ def make_estimator(): ...@@ -43,8 +45,13 @@ def make_estimator():
class Tests(tf.test.TestCase): class Tests(tf.test.TestCase):
"""Run tests for MNIST model.""" """Run tests for MNIST model.
MNIST uses contrib and will not work with TF 2.0. All tests are disabled if
using TF 2.0.
"""
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_mnist(self): def test_mnist(self):
classifier = make_estimator() classifier = make_estimator()
classifier.train(input_fn=dummy_input_fn, steps=2) classifier.train(input_fn=dummy_input_fn, steps=2)
...@@ -64,6 +71,7 @@ class Tests(tf.test.TestCase): ...@@ -64,6 +71,7 @@ class Tests(tf.test.TestCase):
self.assertEqual(predictions['probabilities'].shape, (10,)) self.assertEqual(predictions['probabilities'].shape, (10,))
self.assertEqual(predictions['classes'].shape, ()) self.assertEqual(predictions['classes'].shape, ())
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def mnist_model_fn_helper(self, mode, multi_gpu=False): def mnist_model_fn_helper(self, mode, multi_gpu=False):
features, labels = dummy_input_fn() features, labels = dummy_input_fn()
image_count = features.shape[0] image_count = features.shape[0]
...@@ -91,15 +99,19 @@ class Tests(tf.test.TestCase): ...@@ -91,15 +99,19 @@ class Tests(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_mnist_model_fn_train_mode(self): def test_mnist_model_fn_train_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN) self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_mnist_model_fn_train_mode_multi_gpu(self): def test_mnist_model_fn_train_mode_multi_gpu(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True) self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_mnist_model_fn_eval_mode(self): def test_mnist_model_fn_eval_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.EVAL) self.mnist_model_fn_helper(tf.estimator.ModeKeys.EVAL)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_mnist_model_fn_predict_mode(self): def test_mnist_model_fn_predict_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.PREDICT) self.mnist_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
...@@ -131,5 +143,5 @@ class Benchmarks(tf.test.Benchmark): ...@@ -131,5 +143,5 @@ class Benchmarks(tf.test.Benchmark):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.test.main() tf.test.main()
...@@ -31,6 +31,7 @@ from official.datasets import movielens ...@@ -31,6 +31,7 @@ from official.datasets import movielens
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_preprocessing from official.recommendation import data_preprocessing
from official.recommendation import popen_helper from official.recommendation import popen_helper
from official.utils.misc import keras_utils
DATASET = "ml-test" DATASET = "ml-test"
...@@ -50,12 +51,16 @@ FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22" ...@@ -50,12 +51,16 @@ FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
def mock_download(*args, **kwargs): def mock_download(*args, **kwargs):
return return
# The forkpool used by data producers interacts badly with the threading # The forkpool used by data producers interacts badly with the threading
# used by TestCase. Without this patch tests will hang, and no amount # used by TestCase. Without this patch tests will hang, and no amount
# of diligent closing and joining within the producer will prevent it. # of diligent closing and joining within the producer will prevent it.
@mock.patch.object(popen_helper, "get_forkpool", popen_helper.get_fauxpool) @mock.patch.object(popen_helper, "get_forkpool", popen_helper.get_fauxpool)
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
def setUp(self): def setUp(self):
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
self.temp_data_dir = self.get_temp_dir() self.temp_data_dir = self.get_temp_dir()
ratings_folder = os.path.join(self.temp_data_dir, DATASET) ratings_folder = os.path.join(self.temp_data_dir, DATASET)
tf.io.gfile.makedirs(ratings_folder) tf.io.gfile.makedirs(ratings_folder)
...@@ -119,7 +124,7 @@ class BaseTest(tf.test.TestCase): ...@@ -119,7 +124,7 @@ class BaseTest(tf.test.TestCase):
def drain_dataset(self, dataset, g): def drain_dataset(self, dataset, g):
# type: (tf.data.Dataset, tf.Graph) -> list # type: (tf.data.Dataset, tf.Graph) -> list
with self.test_session(graph=g) as sess: with self.session(graph=g) as sess:
with g.as_default(): with g.as_default():
batch = dataset.make_one_shot_iterator().get_next() batch = dataset.make_one_shot_iterator().get_next()
output = [] output = []
......
...@@ -19,19 +19,21 @@ from __future__ import division ...@@ -19,19 +19,21 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import mock import unittest
import mock
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from absl.testing import flagsaver
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_pipeline from official.recommendation import data_pipeline
from official.recommendation import neumf_model from official.recommendation import neumf_model
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main from official.recommendation import ncf_keras_main
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
...@@ -54,6 +56,7 @@ class NcfTest(tf.test.TestCase): ...@@ -54,6 +56,7 @@ class NcfTest(tf.test.TestCase):
rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
rconst.TOP_K = self.top_k_old rconst.TOP_K = self.top_k_old
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user, def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
top_k=rconst.TOP_K, match_mlperf=False): top_k=rconst.TOP_K, match_mlperf=False):
rconst.TOP_K = top_k rconst.TOP_K = top_k
...@@ -82,10 +85,10 @@ class NcfTest(tf.test.TestCase): ...@@ -82,10 +85,10 @@ class NcfTest(tf.test.TestCase):
hr = metric_ops[rconst.HR_KEY] hr = metric_ops[rconst.HR_KEY]
ndcg = metric_ops[rconst.NDCG_KEY] ndcg = metric_ops[rconst.NDCG_KEY]
init = [tf.global_variables_initializer(), init = [tf.compat.v1.global_variables_initializer(),
tf.local_variables_initializer()] tf.compat.v1.local_variables_initializer()]
with self.test_session(graph=g) as sess: with self.session(graph=g) as sess:
sess.run(init) sess.run(init)
return sess.run([hr[1], ndcg[1]]) return sess.run([hr[1], ndcg[1]])
...@@ -188,12 +191,14 @@ class NcfTest(tf.test.TestCase): ...@@ -188,12 +191,14 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1'] _BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self): def test_end_to_end_estimator(self):
integration.run_synthetic( integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS) extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self): def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic( integration.run_synthetic(
......
...@@ -23,6 +23,7 @@ import numpy as np ...@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main from official.resnet import cifar10_main
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
...@@ -42,6 +43,8 @@ class BaseTest(tf.test.TestCase): ...@@ -42,6 +43,8 @@ class BaseTest(tf.test.TestCase):
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass() super(BaseTest, cls).setUpClass()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
cifar10_main.define_cifar_flags() cifar10_main.define_cifar_flags()
def setUp(self): def setUp(self):
...@@ -76,7 +79,7 @@ class BaseTest(tf.test.TestCase): ...@@ -76,7 +79,7 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(label.shape, ()) self.assertAllEqual(label.shape, ())
self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS)) self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
with self.test_session() as sess: with self.session() as sess:
image, label = sess.run([image, label]) image, label = sess.run([image, label])
self.assertEqual(label, 7) self.assertEqual(label, 7)
......
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
...@@ -41,6 +42,8 @@ class BaseTest(tf.test.TestCase): ...@@ -41,6 +42,8 @@ class BaseTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(BaseTest, self).setUp() super(BaseTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
self._num_validation_images = imagenet_main.NUM_IMAGES['validation'] self._num_validation_images = imagenet_main.NUM_IMAGES['validation']
imagenet_main.NUM_IMAGES['validation'] = 4 imagenet_main.NUM_IMAGES['validation'] = 4
......
...@@ -19,10 +19,10 @@ from __future__ import print_function ...@@ -19,10 +19,10 @@ from __future__ import print_function
from mock import Mock from mock import Mock
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.platform import googletest
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
class KerasCommonTests(tf.test.TestCase): class KerasCommonTests(tf.test.TestCase):
...@@ -36,12 +36,13 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -36,12 +36,13 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy=.99988) history = self._build_history(1.145, cat_accuracy=.99988)
eval_output = self._build_eval_output(.56432111, 5.990) eval_output = self._build_eval_output(.56432111, 5.990)
th = keras_common.TimeHistory(128, 100) th = keras_utils.TimeHistory(128, 100)
th.batch_start_timestamps = [1, 2, 3] th.timestamp_log = [keras_utils.BatchTimestamp(0, 1),
th.batch_end_timestamps = [4, 5, 6] keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345 th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, th) stats = keras_common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
...@@ -49,8 +50,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -49,8 +50,7 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(.56432111, stats['accuracy_top_1']) self.assertEqual(.56432111, stats['accuracy_top_1'])
self.assertEqual(5.990, stats['eval_loss']) self.assertEqual(5.990, stats['eval_loss'])
self.assertItemsEqual([1, 2, 3], stats['batch_start_timestamps']) self.assertEqual(3, stats['step_timestamp_log'][2].timestamp)
self.assertItemsEqual([4, 5, 6], stats['batch_end_timestamps'])
self.assertEqual(12345, stats['train_finish_time']) self.assertEqual(12345, stats['train_finish_time'])
def test_build_stats_sparse(self): def test_build_stats_sparse(self):
...@@ -66,7 +66,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -66,7 +66,7 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(1.9844, stats['eval_loss']) self.assertEqual(1.9844, stats['eval_loss'])
def test_time_history(self): def test_time_history(self):
th = keras_common.TimeHistory(batch_size=128, log_steps=3) th = keras_utils.TimeHistory(batch_size=128, log_steps=3)
th.on_train_begin() th.on_train_begin()
th.on_batch_begin(0) th.on_batch_begin(0)
...@@ -85,15 +85,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -85,15 +85,7 @@ class KerasCommonTests(tf.test.TestCase):
th.on_batch_end(6) th.on_batch_end(6)
th.on_train_end() th.on_train_end()
self.assertEqual(3, len(th.batch_start_timestamps)) self.assertEqual(3, len(th.timestamp_log))
self.assertEqual(2, len(th.batch_end_timestamps))
self.assertEqual(0, th.batch_start_timestamps[0].batch_index)
self.assertEqual(1, th.batch_start_timestamps[1].batch_index)
self.assertEqual(4, th.batch_start_timestamps[2].batch_index)
self.assertEqual(3, th.batch_end_timestamps[0].batch_index)
self.assertEqual(6, th.batch_end_timestamps[1].batch_index)
def _build_history(self, loss, cat_accuracy=None, def _build_history(self, loss, cat_accuracy=None,
cat_accuracy_sparse=None): cat_accuracy_sparse=None):
...@@ -111,3 +103,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -111,3 +103,7 @@ class KerasCommonTests(tf.test.TestCase):
def _build_eval_output(self, top_1, eval_loss): def _build_eval_output(self, top_1, eval_loss):
eval_output = [np.float64(eval_loss), np.float64(top_1)] eval_output = [np.float64(eval_loss), np.float64(top_1)]
return eval_output return eval_output
if __name__ == '__main__':
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
googletest.main()
...@@ -22,7 +22,6 @@ from tempfile import mkdtemp ...@@ -22,7 +22,6 @@ from tempfile import mkdtemp
import tensorflow as tf import tensorflow as tf
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
...@@ -279,5 +278,5 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -279,5 +278,5 @@ class KerasImagenetTest(googletest.TestCase):
) )
if __name__ == '__main__': if __name__ == "__main__":
googletest.main() googletest.main()
...@@ -32,9 +32,11 @@ from __future__ import division ...@@ -32,9 +32,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys import sys
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model from official.resnet import resnet_model
from official.utils.misc import keras_utils
from official.utils.testing import reference_data from official.utils.testing import reference_data
...@@ -63,6 +65,11 @@ BLOCK_TESTS = [ ...@@ -63,6 +65,11 @@ BLOCK_TESTS = [
class BaseTest(reference_data.BaseTest): class BaseTest(reference_data.BaseTest):
"""Tests for core ResNet layers.""" """Tests for core ResNet layers."""
def setUp(self):
super(BaseTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
@property @property
def test_name(self): def test_name(self):
return "resnet" return "resnet"
...@@ -166,16 +173,37 @@ class BaseTest(reference_data.BaseTest): ...@@ -166,16 +173,37 @@ class BaseTest(reference_data.BaseTest):
correctness_function=self.default_correctness_function correctness_function=self.default_correctness_function
) )
@unittest.skipIf(tf.test.is_built_with_cuda(), "Results only match CPU.")
def test_batch_norm(self): def test_batch_norm(self):
"""Tests batch norm layer correctness.
Test fails on a GTX 1080 with the last value being significantly different:
7.629395e-05 (expected) -> -4.159546e-02 (actual). The tests passes on CPU
on TF 1.0 and TF 2.0.
"""
self._batch_norm_ops(test=True) self._batch_norm_ops(test=True)
def test_block_0(self): def test_block_0(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[0]) self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[0])
@unittest.skipIf(tf.test.is_built_with_cuda(), "Results only match CPU.")
def test_block_1(self): def test_block_1(self):
"""Test bottleneck=True, projection=False, resnet_version=1.
Test fails on a GTX 1080 but would pass with tolerances moved from
1e-06 to 1e-05. Being TF 1.0 and this was not setup as a GPU test originally
it makes sense to disable it on GPU vs. research.
"""
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[1]) self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[1])
@unittest.skipIf(tf.test.is_built_with_cuda(), "Results only match CPU.")
def test_block_2(self): def test_block_2(self):
"""Test bottleneck=True, projection=True, resnet_version=2, width=8.
Test fails on a GTX 1080 but would pass with tolerances moved from
1e-06 to 1e-05. Being TF 1.0 and this was not setup as a GPU test originally
it makes sense to disable it on GPU.
"""
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[2]) self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[2])
def test_block_3(self): def test_block_3(self):
......
...@@ -25,7 +25,7 @@ class ComputeBleuTest(tf.test.TestCase): ...@@ -25,7 +25,7 @@ class ComputeBleuTest(tf.test.TestCase):
def _create_temp_file(self, text): def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.gfile.Open(temp_file.name, 'w') as w: with tf.io.gfile.GFile(temp_file.name, "w") as w:
w.write(text) w.write(text)
return temp_file.name return temp_file.name
......
...@@ -25,15 +25,19 @@ from official.transformer.model import beam_search ...@@ -25,15 +25,19 @@ from official.transformer.model import beam_search
class BeamSearchHelperTests(tf.test.TestCase): class BeamSearchHelperTests(tf.test.TestCase):
def setUp(self):
super(BeamSearchHelperTests, self).setUp()
tf.compat.v1.disable_eager_execution()
def test_expand_to_beam_size(self): def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5]) x = tf.ones([7, 4, 2, 5])
x = beam_search._expand_to_beam_size(x, 3) x = beam_search._expand_to_beam_size(x, 3)
with self.test_session() as sess: with self.session() as sess:
shape = sess.run(tf.shape(x)) shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 3, 4, 2, 5], shape) self.assertAllEqual([7, 3, 4, 2, 5], shape)
def test_shape_list(self): def test_shape_list(self):
y = tf.placeholder(dtype=tf.int32, shape=[]) y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[])
x = tf.ones([7, y, 2, 5]) x = tf.ones([7, y, 2, 5])
shape = beam_search._shape_list(x) shape = beam_search._shape_list(x)
self.assertIsInstance(shape[0], int) self.assertIsInstance(shape[0], int)
...@@ -43,7 +47,7 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -43,7 +47,7 @@ class BeamSearchHelperTests(tf.test.TestCase):
def test_get_shape_keep_last_dim(self): def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0) y = tf.constant(4.0)
x = tf.ones([7, tf.to_int32(tf.sqrt(y)), 2, 5]) x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
shape = beam_search._get_shape_keep_last_dim(x) shape = beam_search._get_shape_keep_last_dim(x)
self.assertAllEqual([None, None, None, 5], self.assertAllEqual([None, None, None, 5],
shape.as_list()) shape.as_list())
...@@ -51,14 +55,14 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -51,14 +55,14 @@ class BeamSearchHelperTests(tf.test.TestCase):
def test_flatten_beam_dim(self): def test_flatten_beam_dim(self):
x = tf.ones([7, 4, 2, 5]) x = tf.ones([7, 4, 2, 5])
x = beam_search._flatten_beam_dim(x) x = beam_search._flatten_beam_dim(x)
with self.test_session() as sess: with self.session() as sess:
shape = sess.run(tf.shape(x)) shape = sess.run(tf.shape(x))
self.assertAllEqual([28, 2, 5], shape) self.assertAllEqual([28, 2, 5], shape)
def test_unflatten_beam_dim(self): def test_unflatten_beam_dim(self):
x = tf.ones([28, 2, 5]) x = tf.ones([28, 2, 5])
x = beam_search._unflatten_beam_dim(x, 7, 4) x = beam_search._unflatten_beam_dim(x, 7, 4)
with self.test_session() as sess: with self.session() as sess:
shape = sess.run(tf.shape(x)) shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 4, 2, 5], shape) self.assertAllEqual([7, 4, 2, 5], shape)
...@@ -73,7 +77,7 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -73,7 +77,7 @@ class BeamSearchHelperTests(tf.test.TestCase):
# [20 21 22 23]]] # [20 21 22 23]]]
y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2) y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
with self.test_session() as sess: with self.session() as sess:
y = sess.run(y) y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7], self.assertAllEqual([[[4, 5, 6, 7],
...@@ -87,7 +91,7 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -87,7 +91,7 @@ class BeamSearchHelperTests(tf.test.TestCase):
x_scores = [[0, 1, 1], [1, 0, 1]] x_scores = [[0, 1, 1], [1, 0, 1]]
y = beam_search._gather_topk_beams(x, x_scores, 2, 2) y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
with self.test_session() as sess: with self.session() as sess:
y = sess.run(y) y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7], self.assertAllEqual([[[4, 5, 6, 7],
......
...@@ -21,16 +21,22 @@ from __future__ import print_function ...@@ -21,16 +21,22 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils from official.transformer.model import model_utils
from official.utils.misc import keras_utils
NEG_INF = -1e9 NEG_INF = -1e9
class ModelUtilsTest(tf.test.TestCase): class ModelUtilsTest(tf.test.TestCase):
def setUp(self):
super(ModelUtilsTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_get_padding(self): def test_get_padding(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]]) x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
padding = model_utils.get_padding(x, padding_value=0) padding = model_utils.get_padding(x, padding_value=0)
with self.test_session() as sess: with self.session() as sess:
padding = sess.run(padding) padding = sess.run(padding)
self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]], self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
...@@ -41,7 +47,7 @@ class ModelUtilsTest(tf.test.TestCase): ...@@ -41,7 +47,7 @@ class ModelUtilsTest(tf.test.TestCase):
bias = model_utils.get_padding_bias(x) bias = model_utils.get_padding_bias(x)
bias_shape = tf.shape(bias) bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5]) flattened_bias = tf.reshape(bias, [3, 5])
with self.test_session() as sess: with self.session() as sess:
flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape)) flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape))
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0], self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
...@@ -53,7 +59,7 @@ class ModelUtilsTest(tf.test.TestCase): ...@@ -53,7 +59,7 @@ class ModelUtilsTest(tf.test.TestCase):
def test_get_decoder_self_attention_bias(self): def test_get_decoder_self_attention_bias(self):
length = 5 length = 5
bias = model_utils.get_decoder_self_attention_bias(length) bias = model_utils.get_decoder_self_attention_bias(length)
with self.test_session() as sess: with self.session() as sess:
bias = sess.run(bias) bias = sess.run(bias)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF], self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
......
...@@ -26,7 +26,7 @@ class SubtokenizerTest(tf.test.TestCase): ...@@ -26,7 +26,7 @@ class SubtokenizerTest(tf.test.TestCase):
def _init_subtokenizer(self, vocab_list): def _init_subtokenizer(self, vocab_list):
temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.gfile.Open(temp_file.name, 'w') as w: with tf.io.gfile.GFile(temp_file.name, "w") as w:
for subtoken in vocab_list: for subtoken in vocab_list:
w.write("'%s'" % subtoken) w.write("'%s'" % subtoken)
w.write("\n") w.write("\n")
......
...@@ -28,6 +28,9 @@ import tensorflow as tf ...@@ -28,6 +28,9 @@ import tensorflow as tf
from official.transformer.v2 import misc from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as tm from official.transformer.v2 import transformer_main as tm
from official.utils.misc import keras_utils
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp' FIXED_TIMESTAMP = 'my_time_stamp'
...@@ -80,19 +83,26 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -80,19 +83,26 @@ class TransformerTaskTest(tf.test.TestCase):
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_static_batch(self): def test_train_static_batch(self):
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_1_gpu_with_dist_strat(self): def test_train_1_gpu_with_dist_strat(self):
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu(self): def test_train_2_gpu(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
...@@ -100,7 +110,12 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -100,7 +110,12 @@ class TransformerTaskTest(tf.test.TestCase):
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu_fp16(self): def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
......
...@@ -28,6 +28,7 @@ import tensorflow as tf ...@@ -28,6 +28,7 @@ import tensorflow as tf
# pylint: enable=wrong-import-order # pylint: enable=wrong-import-order
from official.utils.data import file_io from official.utils.data import file_io
from official.utils.misc import keras_utils
_RAW_ROW = "raw_row" _RAW_ROW = "raw_row"
...@@ -105,6 +106,11 @@ def fixed_core_count(cpu_count): ...@@ -105,6 +106,11 @@ def fixed_core_count(cpu_count):
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
def setUp(self):
super(BaseTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def _test_sharding(self, row_count, cpu_count, expected): def _test_sharding(self, row_count, cpu_count, expected):
df = pd.DataFrame({_DUMMY_COL: list(range(row_count))}) df = pd.DataFrame({_DUMMY_COL: list(range(row_count))})
with fixed_core_count(cpu_count): with fixed_core_count(cpu_count):
...@@ -153,7 +159,7 @@ class BaseTest(tf.test.TestCase): ...@@ -153,7 +159,7 @@ class BaseTest(tf.test.TestCase):
buffer_path = file_io.write_to_temp_buffer( buffer_path = file_io.write_to_temp_buffer(
df, self.get_temp_dir(), [_RAW_ROW, _DUMMY_COL, _DUMMY_VEC_COL]) df, self.get_temp_dir(), [_RAW_ROW, _DUMMY_COL, _DUMMY_VEC_COL])
with self.test_session(graph=tf.Graph()) as sess: with self.session(graph=tf.Graph()) as sess:
dataset = tf.data.TFRecordDataset(buffer_path) dataset = tf.data.TFRecordDataset(buffer_path)
dataset = dataset.batch(1).map( dataset = dataset.batch(1).map(
lambda x: tf.io.parse_example(serialized=x, features=_FEATURE_MAP)) lambda x: tf.io.parse_example(serialized=x, features=_FEATURE_MAP))
......
...@@ -24,10 +24,16 @@ import unittest ...@@ -24,10 +24,16 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.misc import keras_utils
class BaseTest(unittest.TestCase): class BaseTest(unittest.TestCase):
def setUp(self):
super(BaseTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_raise_in_non_list_names(self): def test_raise_in_non_list_names(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
hooks_helper.get_train_hooks( hooks_helper.get_train_hooks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment