Unverified Commit 49097655 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Unit tests pass TF 2.0 GPU and CPU locally. (#7101)

* Fix unit tests failures.

* 96% of TF 2.0 tests on GPU are passing.

* Currently all passing GPU and CPU TF 2.0

* Address code comments.

* use tf 2.0 cast.

* Comment about working on TF 2.0 CPU

* Uses contrib turn off for TF 2.0.

* Fix wide_deep and add keras_common_tests.

* use context to get num_gpus.

* Switch to tf.keras.metrics
parent 5175b7e6
...@@ -309,14 +309,10 @@ def _gather_run_info(model_name, dataset_name, run_params, test_id): ...@@ -309,14 +309,10 @@ def _gather_run_info(model_name, dataset_name, run_params, test_id):
"test_id": test_id, "test_id": test_id,
"run_date": datetime.datetime.utcnow().strftime( "run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)} _DATE_TIME_FORMAT_PATTERN)}
session_config = None
if "session_config" in run_params:
session_config = run_params["session_config"]
_collect_tensorflow_info(run_info) _collect_tensorflow_info(run_info)
_collect_tensorflow_environment_variables(run_info) _collect_tensorflow_environment_variables(run_info)
_collect_run_params(run_info, run_params) _collect_run_params(run_info, run_params)
_collect_cpu_info(run_info) _collect_cpu_info(run_info)
_collect_gpu_info(run_info, session_config)
_collect_memory_info(run_info) _collect_memory_info(run_info)
_collect_test_environment(run_info) _collect_test_environment(run_info)
return run_info return run_info
...@@ -391,24 +387,6 @@ def _collect_cpu_info(run_info): ...@@ -391,24 +387,6 @@ def _collect_cpu_info(run_info):
"'cpuinfo' not imported. CPU info will not be logged.") "'cpuinfo' not imported. CPU info will not be logged.")
def _collect_gpu_info(run_info, session_config=None):
"""Collect local GPU information by TF device library."""
gpu_info = {}
local_device_protos = device_lib.list_local_devices(session_config)
gpu_info["count"] = len([d for d in local_device_protos
if d.device_type == "GPU"])
# The device description usually is a JSON string, which contains the GPU
# model info, eg:
# "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0"
for d in local_device_protos:
if d.device_type == "GPU":
gpu_info["model"] = _parse_gpu_model(d.physical_device_desc)
# Assume all the GPU connected are same model
break
run_info["machine_config"]["gpu_info"] = gpu_info
def _collect_memory_info(run_info): def _collect_memory_info(run_info):
try: try:
# Note: psutil is not installed in the TensorFlow OSS tree. # Note: psutil is not installed in the TensorFlow OSS tree.
......
...@@ -34,6 +34,7 @@ try: ...@@ -34,6 +34,7 @@ try:
except ImportError: except ImportError:
bigquery = None bigquery = None
from official.utils.misc import keras_utils
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
...@@ -46,12 +47,12 @@ class BenchmarkLoggerTest(tf.test.TestCase): ...@@ -46,12 +47,12 @@ class BenchmarkLoggerTest(tf.test.TestCase):
flags_core.define_benchmark() flags_core.define_benchmark()
def test_get_default_benchmark_logger(self): def test_get_default_benchmark_logger(self):
with flagsaver.flagsaver(benchmark_logger_type='foo'): with flagsaver.flagsaver(benchmark_logger_type="foo"):
self.assertIsInstance(logger.get_benchmark_logger(), self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger) logger.BaseBenchmarkLogger)
def test_config_base_benchmark_logger(self): def test_config_base_benchmark_logger(self):
with flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger'): with flagsaver.flagsaver(benchmark_logger_type="BaseBenchmarkLogger"):
logger.config_benchmark_logger() logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(), self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger) logger.BaseBenchmarkLogger)
...@@ -59,16 +60,16 @@ class BenchmarkLoggerTest(tf.test.TestCase): ...@@ -59,16 +60,16 @@ class BenchmarkLoggerTest(tf.test.TestCase):
def test_config_benchmark_file_logger(self): def test_config_benchmark_file_logger(self):
# Set the benchmark_log_dir first since the benchmark_logger_type will need # Set the benchmark_log_dir first since the benchmark_logger_type will need
# the value to be set when it does the validation. # the value to be set when it does the validation.
with flagsaver.flagsaver(benchmark_log_dir='/tmp'): with flagsaver.flagsaver(benchmark_log_dir="/tmp"):
with flagsaver.flagsaver(benchmark_logger_type='BenchmarkFileLogger'): with flagsaver.flagsaver(benchmark_logger_type="BenchmarkFileLogger"):
logger.config_benchmark_logger() logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(), self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkFileLogger) logger.BenchmarkFileLogger)
@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.') @unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
@mock.patch.object(bigquery, "Client") @mock.patch.object(bigquery, "Client")
def test_config_benchmark_bigquery_logger(self, mock_bigquery_client): def test_config_benchmark_bigquery_logger(self, mock_bigquery_client):
with flagsaver.flagsaver(benchmark_logger_type='BenchmarkBigQueryLogger'): with flagsaver.flagsaver(benchmark_logger_type="BenchmarkBigQueryLogger"):
logger.config_benchmark_logger() logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(), self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkBigQueryLogger) logger.BenchmarkBigQueryLogger)
...@@ -261,9 +262,15 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -261,9 +262,15 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
{"name": "batch_size", "long_value": 32}) {"name": "batch_size", "long_value": 32})
self.assertEqual(run_info["run_parameters"][1], self.assertEqual(run_info["run_parameters"][1],
{"name": "dtype", "string_value": "fp16"}) {"name": "dtype", "string_value": "fp16"})
self.assertEqual(run_info["run_parameters"][2], if keras_utils.is_v2_0():
{"name": "random_tensor", "string_value": self.assertEqual(run_info["run_parameters"][2],
"Tensor(\"Const:0\", shape=(), dtype=float32)"}) {"name": "random_tensor", "string_value":
"tf.Tensor(2.0, shape=(), dtype=float32)"})
else:
self.assertEqual(run_info["run_parameters"][2],
{"name": "random_tensor", "string_value":
"Tensor(\"Const:0\", shape=(), dtype=float32)"})
self.assertEqual(run_info["run_parameters"][3], self.assertEqual(run_info["run_parameters"][3],
{"name": "resnet_size", "long_value": 50}) {"name": "resnet_size", "long_value": 50})
self.assertEqual(run_info["run_parameters"][4], self.assertEqual(run_info["run_parameters"][4],
...@@ -286,12 +293,6 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -286,12 +293,6 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
self.assertEqual(run_info["tensorflow_environment_variables"], self.assertEqual(run_info["tensorflow_environment_variables"],
expected_tf_envs) expected_tf_envs)
@unittest.skipUnless(tf.test.is_built_with_cuda(), "requires GPU")
def test_collect_gpu_info(self):
run_info = {"machine_config": {}}
logger._collect_gpu_info(run_info)
self.assertNotEqual(run_info["machine_config"]["gpu_info"], {})
def test_collect_memory_info(self): def test_collect_memory_info(self):
run_info = {"machine_config": {}} run_info = {"machine_config": {}}
logger._collect_memory_info(run_info) logger._collect_memory_info(run_info)
...@@ -299,7 +300,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -299,7 +300,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
self.assertIsNotNone(run_info["machine_config"]["memory_available"]) self.assertIsNotNone(run_info["machine_config"]["memory_available"])
@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.') @unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BenchmarkBigQueryLoggerTest(tf.test.TestCase): class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" Tests for Model Helper functions.""" """Tests for Model Helper functions."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -20,12 +20,18 @@ from __future__ import print_function ...@@ -20,12 +20,18 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
class PastStopThresholdTest(tf.test.TestCase): class PastStopThresholdTest(tf.test.TestCase):
"""Tests for past_stop_threshold.""" """Tests for past_stop_threshold."""
def setUp(self):
super(PastStopThresholdTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_past_stop_threshold(self): def test_past_stop_threshold(self):
"""Tests for normal operating conditions.""" """Tests for normal operating conditions."""
self.assertTrue(model_helpers.past_stop_threshold(0.54, 1)) self.assertTrue(model_helpers.past_stop_threshold(0.54, 1))
...@@ -77,7 +83,7 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -77,7 +83,7 @@ class SyntheticDataTest(tf.test.TestCase):
label_value=456, label_value=456,
label_dtype=tf.int32)).get_next() label_dtype=tf.int32)).get_next()
with self.test_session() as sess: with self.session() as sess:
for n in range(5): for n in range(5):
inp, lab = sess.run((input_element, label_element)) inp, lab = sess.run((input_element, label_element))
self.assertAllClose(inp, [123., 123., 123., 123., 123.]) self.assertAllClose(inp, [123., 123., 123., 123., 123.])
...@@ -92,7 +98,7 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -92,7 +98,7 @@ class SyntheticDataTest(tf.test.TestCase):
element = tf.compat.v1.data.make_one_shot_iterator(d).get_next() element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertFalse(isinstance(element, tuple)) self.assertFalse(isinstance(element, tuple))
with self.test_session() as sess: with self.session() as sess:
inp = sess.run(element) inp = sess.run(element)
self.assertAllClose(inp, [43.5, 43.5, 43.5, 43.5]) self.assertAllClose(inp, [43.5, 43.5, 43.5, 43.5])
...@@ -110,7 +116,7 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -110,7 +116,7 @@ class SyntheticDataTest(tf.test.TestCase):
self.assertIn('d', element['b']) self.assertIn('d', element['b'])
self.assertNotIn('c', element) self.assertNotIn('c', element)
with self.test_session() as sess: with self.session() as sess:
inp = sess.run(element) inp = sess.run(element)
self.assertAllClose(inp['a'], [1.1, 1.1]) self.assertAllClose(inp['a'], [1.1, 1.1])
self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1]) self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1])
......
...@@ -177,7 +177,7 @@ class BaseTest(tf.test.TestCase): ...@@ -177,7 +177,7 @@ class BaseTest(tf.test.TestCase):
init = tf.compat.v1.global_variables_initializer() init = tf.compat.v1.global_variables_initializer()
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver()
with self.test_session(graph=graph) as sess: with self.session(graph=graph) as sess:
sess.run(init) sess.run(init)
saver.save(sess=sess, save_path=os.path.join(data_dir, self.ckpt_prefix)) saver.save(sess=sess, save_path=os.path.join(data_dir, self.ckpt_prefix))
...@@ -244,7 +244,7 @@ class BaseTest(tf.test.TestCase): ...@@ -244,7 +244,7 @@ class BaseTest(tf.test.TestCase):
tf.version.VERSION, tf.version.GIT_VERSION) tf.version.VERSION, tf.version.GIT_VERSION)
) )
with self.test_session(graph=graph) as sess: with self.session(graph=graph) as sess:
sess.run(init) sess.run(init)
try: try:
saver.restore(sess=sess, save_path=os.path.join( saver.restore(sess=sess, save_path=os.path.join(
......
...@@ -29,15 +29,20 @@ from __future__ import print_function ...@@ -29,15 +29,20 @@ from __future__ import print_function
import sys import sys
import unittest import unittest
import warnings
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import keras_utils
from official.utils.testing import reference_data from official.utils.testing import reference_data
class GoldenBaseTest(reference_data.BaseTest): class GoldenBaseTest(reference_data.BaseTest):
"""Class to ensure that reference data testing runs properly.""" """Class to ensure that reference data testing runs properly."""
def setUp(self):
if keras_utils.is_v2_0():
tf.compat.v1.disable_eager_execution()
super(GoldenBaseTest, self).setUp()
@property @property
def test_name(self): def test_name(self):
return "reference_data_test" return "reference_data_test"
...@@ -75,7 +80,6 @@ class GoldenBaseTest(reference_data.BaseTest): ...@@ -75,7 +80,6 @@ class GoldenBaseTest(reference_data.BaseTest):
result = float(tensor_result[0, 0]) result = float(tensor_result[0, 0])
result = result + 0.1 if bad_function else result result = result + 0.1 if bad_function else result
return [result] return [result]
self._save_or_test_ops( self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[input_tensor], test=test, name=name, graph=g, ops_to_eval=[input_tensor], test=test,
correctness_function=correctness_function correctness_function=correctness_function
...@@ -106,6 +110,7 @@ class GoldenBaseTest(reference_data.BaseTest): ...@@ -106,6 +110,7 @@ class GoldenBaseTest(reference_data.BaseTest):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self._uniform_random_ops(test=True, wrong_name=True) self._uniform_random_ops(test=True, wrong_name=True)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO:(b/136010138) Fails on TF 2.0.")
def test_tensor_shape_error(self): def test_tensor_shape_error(self):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self._uniform_random_ops(test=True, wrong_shape=True) self._uniform_random_ops(test=True, wrong_shape=True)
......
...@@ -62,7 +62,12 @@ py_test() { ...@@ -62,7 +62,12 @@ py_test() {
for test_file in `find official/ -name '*test.py' -print` for test_file in `find official/ -name '*test.py' -print`
do do
echo "####=======Testing ${test_file}=======####" echo "####=======Testing ${test_file}=======####"
${PY_BINARY} "${test_file}" || exit_code=$? ${PY_BINARY} "${test_file}"
_exit_code=$?
if [[ $_exit_code != 0 ]]; then
exit_code=$_exit_code
echo "FAIL: ${test_file}"
fi
done done
return "${exit_code}" return "${exit_code}"
......
...@@ -18,15 +18,16 @@ from __future__ import division ...@@ -18,15 +18,16 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
from official.wide_deep import census_dataset from official.wide_deep import census_dataset
from official.wide_deep import census_main from official.wide_deep import census_main
from official.wide_deep import wide_deep_run_loop
tf.logging.set_verbosity(tf.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,' TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
'Husband,zyx,wvu,34,56,78,tsr,<=50K') 'Husband,zyx,wvu,34,56,78,tsr,<=50K')
...@@ -59,17 +60,19 @@ class BaseTest(tf.test.TestCase): ...@@ -59,17 +60,19 @@ class BaseTest(tf.test.TestCase):
# Create temporary CSV file # Create temporary CSV file
self.temp_dir = self.get_temp_dir() self.temp_dir = self.get_temp_dir()
self.input_csv = os.path.join(self.temp_dir, 'test.csv') self.input_csv = os.path.join(self.temp_dir, 'test.csv')
with tf.gfile.Open(self.input_csv, 'w') as temp_csv: with tf.io.gfile.GFile(self.input_csv, 'w') as temp_csv:
temp_csv.write(TEST_INPUT) temp_csv.write(TEST_INPUT)
with tf.gfile.Open(TEST_CSV, "r") as temp_csv: with tf.io.gfile.GFile(TEST_CSV, 'r') as temp_csv:
test_csv_contents = temp_csv.read() test_csv_contents = temp_csv.read()
# Used for end-to-end tests. # Used for end-to-end tests.
for fname in [census_dataset.TRAINING_FILE, census_dataset.EVAL_FILE]: for fname in [census_dataset.TRAINING_FILE, census_dataset.EVAL_FILE]:
with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv: with tf.io.gfile.GFile(
os.path.join(self.temp_dir, fname), 'w') as test_csv:
test_csv.write(test_csv_contents) test_csv.write(test_csv_contents)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_input_fn(self): def test_input_fn(self):
dataset = census_dataset.input_fn(self.input_csv, 1, False, 1) dataset = census_dataset.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next() features, labels = dataset.make_one_shot_iterator().get_next()
...@@ -123,9 +126,11 @@ class BaseTest(tf.test.TestCase): ...@@ -123,9 +126,11 @@ class BaseTest(tf.test.TestCase):
initial_results['auc_precision_recall']) initial_results['auc_precision_recall'])
self.assertGreater(final_results['accuracy'], initial_results['accuracy']) self.assertGreater(final_results['accuracy'], initial_results['accuracy'])
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_wide_deep_estimator_training(self): def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep') self.build_and_test_estimator('wide_deep')
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_wide(self): def test_end_to_end_wide(self):
integration.run_synthetic( integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(), main=census_main.main, tmp_root=self.get_temp_dir(),
...@@ -136,6 +141,7 @@ class BaseTest(tf.test.TestCase): ...@@ -136,6 +141,7 @@ class BaseTest(tf.test.TestCase):
], ],
synth=False, max_train=None) synth=False, max_train=None)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_deep(self): def test_end_to_end_deep(self):
integration.run_synthetic( integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(), main=census_main.main, tmp_root=self.get_temp_dir(),
...@@ -146,6 +152,7 @@ class BaseTest(tf.test.TestCase): ...@@ -146,6 +152,7 @@ class BaseTest(tf.test.TestCase):
], ],
synth=False, max_train=None) synth=False, max_train=None)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_wide_deep(self): def test_end_to_end_wide_deep(self):
integration.run_synthetic( integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(), main=census_main.main, tmp_root=self.get_temp_dir(),
......
...@@ -35,12 +35,14 @@ from official.utils.flags import core as flags_core ...@@ -35,12 +35,14 @@ from official.utils.flags import core as flags_core
_BUFFER_SUBDIR = "wide_deep_buffer" _BUFFER_SUBDIR = "wide_deep_buffer"
_FEATURE_MAP = { _FEATURE_MAP = {
movielens.USER_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64), movielens.USER_COLUMN: tf.compat.v1.FixedLenFeature([1], dtype=tf.int64),
movielens.ITEM_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64), movielens.ITEM_COLUMN: tf.compat.v1.FixedLenFeature([1], dtype=tf.int64),
movielens.TIMESTAMP_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64), movielens.TIMESTAMP_COLUMN: tf.compat.v1.FixedLenFeature([1],
movielens.GENRE_COLUMN: tf.FixedLenFeature( dtype=tf.int64),
movielens.GENRE_COLUMN: tf.compat.v1.FixedLenFeature(
[movielens.N_GENRE], dtype=tf.int64), [movielens.N_GENRE], dtype=tf.int64),
movielens.RATING_COLUMN: tf.FixedLenFeature([1], dtype=tf.float32), movielens.RATING_COLUMN: tf.compat.v1.FixedLenFeature([1],
dtype=tf.float32),
} }
_BUFFER_SIZE = { _BUFFER_SIZE = {
......
...@@ -18,17 +18,18 @@ from __future__ import division ...@@ -18,17 +18,18 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import unittest
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.datasets import movielens from official.datasets import movielens
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
from official.wide_deep import movielens_dataset from official.wide_deep import movielens_dataset
from official.wide_deep import movielens_main from official.wide_deep import movielens_main
from official.wide_deep import wide_deep_run_loop
tf.logging.set_verbosity(tf.logging.ERROR) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
TEST_INPUT_VALUES = { TEST_INPUT_VALUES = {
...@@ -70,20 +71,20 @@ class BaseTest(tf.test.TestCase): ...@@ -70,20 +71,20 @@ class BaseTest(tf.test.TestCase):
def setUp(self): def setUp(self):
# Create temporary CSV file # Create temporary CSV file
self.temp_dir = self.get_temp_dir() self.temp_dir = self.get_temp_dir()
tf.gfile.MakeDirs(os.path.join(self.temp_dir, movielens.ML_1M)) tf.io.gfile.makedirs(os.path.join(self.temp_dir, movielens.ML_1M))
self.ratings_csv = os.path.join( self.ratings_csv = os.path.join(
self.temp_dir, movielens.ML_1M, movielens.RATINGS_FILE) self.temp_dir, movielens.ML_1M, movielens.RATINGS_FILE)
self.item_csv = os.path.join( self.item_csv = os.path.join(
self.temp_dir, movielens.ML_1M, movielens.MOVIES_FILE) self.temp_dir, movielens.ML_1M, movielens.MOVIES_FILE)
with tf.gfile.Open(self.ratings_csv, "w") as f: with tf.io.gfile.GFile(self.ratings_csv, "w") as f:
f.write(TEST_RATING_DATA) f.write(TEST_RATING_DATA)
with tf.gfile.Open(self.item_csv, "w") as f: with tf.io.gfile.GFile(self.item_csv, "w") as f:
f.write(TEST_ITEM_DATA) f.write(TEST_ITEM_DATA)
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_input_fn(self): def test_input_fn(self):
train_input_fn, _, _ = movielens_dataset.construct_input_fns( train_input_fn, _, _ = movielens_dataset.construct_input_fns(
dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1) dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1)
...@@ -91,7 +92,7 @@ class BaseTest(tf.test.TestCase): ...@@ -91,7 +92,7 @@ class BaseTest(tf.test.TestCase):
dataset = train_input_fn() dataset = train_input_fn()
features, labels = dataset.make_one_shot_iterator().get_next() features, labels = dataset.make_one_shot_iterator().get_next()
with self.test_session() as sess: with self.session() as sess:
features, labels = sess.run((features, labels)) features, labels = sess.run((features, labels))
# Compare the two features dictionaries. # Compare the two features dictionaries.
...@@ -101,6 +102,7 @@ class BaseTest(tf.test.TestCase): ...@@ -101,6 +102,7 @@ class BaseTest(tf.test.TestCase):
self.assertAllClose(labels[0], [1.0]) self.assertAllClose(labels[0], [1.0])
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_end_to_end_deep(self): def test_end_to_end_deep(self):
integration.run_synthetic( integration.run_synthetic(
main=movielens_main.main, tmp_root=self.temp_dir, main=movielens_main.main, tmp_root=self.temp_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment