Commit b2c9e3f5 authored by Goldie Gadde's avatar Goldie Gadde Committed by Toby Boyd
Browse files

Revert "Revert "tf_upgrade_v2 on resnet and utils folders. (#6154)" (#6162)" (#6167)

This reverts commit 57e07520.
parent 57e07520
......@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
valid_flags = True
for key in flag_names:
if not flag_values[key].startswith("gs://"):
tf.logging.error("{} must be a GCS path.".format(key))
tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
valid_flags = False
return valid_flags
......
......@@ -25,7 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import logger
class ExamplesPerSecondHook(tf.train.SessionRunHook):
class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
......@@ -66,7 +66,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.train.SecondOrStepTimer(
self._timer = tf.estimator.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0
......@@ -76,7 +76,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.train.get_global_step()
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
......@@ -90,7 +90,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
Returns:
A SessionRunArgs object or None if never triggered.
"""
return tf.train.SessionRunArgs(self._global_step_tensor)
return tf.estimator.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
......
......@@ -57,7 +57,7 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return []
if use_tpu:
tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
"specified. No hooks will be used.".format(name_list))
return []
......@@ -89,7 +89,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): #
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return tf.train.LoggingTensorHook(
return tf.estimator.LoggingTensorHook(
tensors=tensors_to_log,
every_n_iter=every_n_iter)
......@@ -106,7 +106,7 @@ def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return tf.train.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
def get_examples_per_second_hook(every_n_steps=100,
......
......@@ -45,7 +45,7 @@ class BaseTest(unittest.TestCase):
returned_hook = hooks_helper.get_train_hooks(
[test_hook_name], model_dir="", **kwargs)
self.assertEqual(len(returned_hook), 1)
self.assertIsInstance(returned_hook[0], tf.train.SessionRunHook)
self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook)
self.assertEqual(returned_hook[0].__class__.__name__.lower(),
expected_hook_name)
......
......@@ -26,7 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import hooks
from official.utils.testing import mock_lib
tf.logging.set_verbosity(tf.logging.DEBUG)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
class ExamplesPerSecondHookTest(tf.test.TestCase):
......@@ -44,9 +44,10 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
self.graph = tf.Graph()
with self.graph.as_default():
tf.train.create_global_step()
self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
self.global_step = tf.train.get_global_step()
tf.compat.v1.train.create_global_step()
self.train_op = tf.compat.v1.assign_add(
tf.compat.v1.train.get_global_step(), 1)
self.global_step = tf.compat.v1.train.get_global_step()
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
......@@ -71,8 +72,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
warm_steps=warm_steps,
metric_logger=self._logger)
with tf.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
with tf.compat.v1.train.MonitoredSession(
tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
for _ in range(every_n_steps):
# Explicitly run global_step after train_op to get the accurate
# global_step value
......@@ -125,8 +126,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
every_n_secs=every_n_secs,
metric_logger=self._logger)
with tf.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
with tf.compat.v1.train.MonitoredSession(
tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess.run(self.train_op)
......
......@@ -119,12 +119,13 @@ class BaseBenchmarkLogger(object):
eval_results: dict, the result of evaluate.
"""
if not isinstance(eval_results, dict):
tf.logging.warning("eval_results should be dictionary for logging. "
"Got %s", type(eval_results))
tf.compat.v1.logging.warning(
"eval_results should be dictionary for logging. Got %s",
type(eval_results))
return
global_step = eval_results[tf.GraphKeys.GLOBAL_STEP]
global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
for key in sorted(eval_results):
if key != tf.GraphKeys.GLOBAL_STEP:
if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
self.log_metric(key, eval_results[key], global_step=global_step)
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
......@@ -143,12 +144,12 @@ class BaseBenchmarkLogger(object):
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
tf.logging.info("Benchmark metric: %s", metric)
tf.compat.v1.logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
tf.logging.info("Benchmark run: %s",
_gather_run_info(model_name, dataset_name, run_params,
test_id))
tf.compat.v1.logging.info(
"Benchmark run: %s", _gather_run_info(model_name, dataset_name,
run_params, test_id))
def on_finish(self, status):
pass
......@@ -160,9 +161,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
def __init__(self, logging_dir):
super(BenchmarkFileLogger, self).__init__()
self._logging_dir = logging_dir
if not tf.gfile.IsDirectory(self._logging_dir):
tf.gfile.MakeDirs(self._logging_dir)
self._metric_file_handler = tf.gfile.GFile(
if not tf.io.gfile.isdir(self._logging_dir):
tf.io.gfile.makedirs(self._logging_dir)
self._metric_file_handler = tf.io.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
......@@ -186,8 +187,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
self._metric_file_handler.write("\n")
self._metric_file_handler.flush()
except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump metric to log file: "
"name %s, value %s, error %s", name, value, e)
tf.compat.v1.logging.warning(
"Failed to dump metric to log file: name %s, value %s, error %s",
name, value, e)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
......@@ -204,14 +206,14 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
"""
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
with tf.gfile.GFile(os.path.join(
with tf.io.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
try:
json.dump(run_info, f)
f.write("\n")
except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump benchmark run info to log file: %s",
e)
tf.compat.v1.logging.warning(
"Failed to dump benchmark run info to log file: %s", e)
def on_finish(self, status):
self._metric_file_handler.flush()
......@@ -324,7 +326,7 @@ def _process_metric_to_json(
name, value, unit=None, global_step=None, extras=None):
"""Validate the metric data and generate JSON for insert."""
if not isinstance(value, numbers.Number):
tf.logging.warning(
tf.compat.v1.logging.warning(
"Metric value to log should be a number. Got %s", type(value))
return None
......@@ -341,7 +343,7 @@ def _process_metric_to_json(
def _collect_tensorflow_info(run_info):
run_info["tensorflow_version"] = {
"version": tf.VERSION, "git_hash": tf.GIT_VERSION}
"version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}
def _collect_run_params(run_info, run_params):
......@@ -385,7 +387,8 @@ def _collect_cpu_info(run_info):
run_info["machine_config"]["cpu_info"] = cpu_info
except ImportError:
tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.")
tf.compat.v1.logging.warn(
"'cpuinfo' not imported. CPU info will not be logged.")
def _collect_gpu_info(run_info, session_config=None):
......@@ -415,7 +418,8 @@ def _collect_memory_info(run_info):
run_info["machine_config"]["memory_total"] = vmem.total
run_info["machine_config"]["memory_available"] = vmem.available
except ImportError:
tf.logging.warn("'psutil' not imported. Memory info will not be logged.")
tf.compat.v1.logging.warn(
"'psutil' not imported. Memory info will not be logged.")
def _collect_test_environment(run_info):
......
......@@ -78,7 +78,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
mock_logger = mock.MagicMock()
mock_config_benchmark_logger.return_value = mock_logger
with logger.benchmark_context(None):
tf.logging.info("start benchmarking")
tf.compat.v1.logging.info("start benchmarking")
mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
@mock.patch("official.utils.logs.logger.config_benchmark_logger")
......@@ -95,18 +95,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
def setUp(self):
super(BaseBenchmarkLoggerTest, self).setUp()
self._actual_log = tf.logging.info
self._actual_log = tf.compat.v1.logging.info
self.logged_message = None
def mock_log(*args, **kwargs):
self.logged_message = args
self._actual_log(*args, **kwargs)
tf.logging.info = mock_log
tf.compat.v1.logging.info = mock_log
def tearDown(self):
super(BaseBenchmarkLoggerTest, self).tearDown()
tf.logging.info = self._actual_log
tf.compat.v1.logging.info = self._actual_log
def test_log_metric(self):
log = logger.BaseBenchmarkLogger()
......@@ -128,16 +128,16 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
def tearDown(self):
super(BenchmarkFileLoggerTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
tf.io.gfile.rmtree(self.get_temp_dir())
os.environ.clear()
os.environ.update(self.original_environ)
def test_create_logging_dir(self):
non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir")
self.assertFalse(tf.gfile.IsDirectory(non_exist_temp_dir))
self.assertFalse(tf.io.gfile.isdir(non_exist_temp_dir))
logger.BenchmarkFileLogger(non_exist_temp_dir)
self.assertTrue(tf.gfile.IsDirectory(non_exist_temp_dir))
self.assertTrue(tf.io.gfile.isdir(non_exist_temp_dir))
def test_log_metric(self):
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
......@@ -145,8 +145,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log))
with tf.gfile.GFile(metric_log) as f:
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
metric = json.loads(f.readline())
self.assertEqual(metric["name"], "accuracy")
self.assertEqual(metric["value"], 0.999)
......@@ -161,8 +161,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("loss", 0.02, global_step=1e4)
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log))
with tf.gfile.GFile(metric_log) as f:
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.999)
......@@ -184,7 +184,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("accuracy", const)
metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.gfile.Exists(metric_log))
self.assertFalse(tf.io.gfile.exists(metric_log))
def test_log_evaluation_result(self):
eval_result = {"loss": 0.46237424,
......@@ -195,8 +195,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log))
with tf.gfile.GFile(metric_log) as f:
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.9285)
......@@ -216,7 +216,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.gfile.Exists(metric_log))
self.assertFalse(tf.io.gfile.exists(metric_log))
@mock.patch("official.utils.logs.logger._gather_run_info")
def test_log_run_info(self, mock_gather_run_info):
......@@ -229,8 +229,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_run_info("model_name", "dataset_name", {})
run_log = os.path.join(log_dir, "benchmark_run.log")
self.assertTrue(tf.gfile.Exists(run_log))
with tf.gfile.GFile(run_log) as f:
self.assertTrue(tf.io.gfile.exists(run_log))
with tf.io.gfile.GFile(run_log) as f:
run_info = json.loads(f.readline())
self.assertEqual(run_info["model_name"], "model_name")
self.assertEqual(run_info["dataset"], "dataset_name")
......@@ -240,8 +240,10 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
run_info = {}
logger._collect_tensorflow_info(run_info)
self.assertNotEqual(run_info["tensorflow_version"], {})
self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION)
self.assertEqual(run_info["tensorflow_version"]["version"],
tf.version.VERSION)
self.assertEqual(run_info["tensorflow_version"]["git_hash"],
tf.version.GIT_VERSION)
def test_collect_run_params(self):
run_info = {}
......@@ -315,7 +317,7 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
def tearDown(self):
super(BenchmarkBigQueryLoggerTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
tf.io.gfile.rmtree(self.get_temp_dir())
os.environ.clear()
os.environ.update(self.original_environ)
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
class LoggingMetricHook(tf.train.LoggingTensorHook):
class LoggingMetricHook(tf.estimator.LoggingTensorHook):
"""Hook to log benchmark metric information.
This hook is very similar as tf.train.LoggingTensorHook, which logs given
......@@ -68,7 +68,7 @@ class LoggingMetricHook(tf.train.LoggingTensorHook):
def begin(self):
super(LoggingMetricHook, self).begin()
self._global_step_tensor = tf.train.get_global_step()
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use LoggingMetricHook.")
......
......@@ -39,7 +39,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
def tearDown(self):
super(LoggingMetricHookTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
tf.io.gfile.rmtree(self.get_temp_dir())
def test_illegal_args(self):
with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
......@@ -55,15 +55,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
def test_print_at_end_only(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
t = tf.constant(42.0, name="foo")
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3):
mon_sess.run(train_op)
......@@ -88,8 +88,8 @@ class LoggingMetricHookTest(tf.test.TestCase):
hook.begin()
def test_log_tensors(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
t1 = tf.constant(42.0, name="foo")
t2 = tf.constant(43.0, name="bar")
train_op = tf.constant(3)
......@@ -97,7 +97,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
tensors=[t1, t2], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3):
mon_sess.run(train_op)
......@@ -126,7 +126,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
for _ in range(3):
......@@ -153,15 +153,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_steps(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=False)
def test_print_every_n_steps_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=True)
......@@ -175,7 +175,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
......@@ -199,15 +199,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_secs(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=False)
def test_print_every_n_secs_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=True)
......
......@@ -94,7 +94,7 @@ def get_mlperf_log():
version = pkg_resources.get_distribution("mlperf_compliance")
version = tuple(int(i) for i in version.version.split("."))
if version < _MIN_VERSION:
tf.logging.warning(
tf.compat.v1.logging.warning(
"mlperf_compliance is version {}, must be >= {}".format(
".".join([str(i) for i in version]),
".".join([str(i) for i in _MIN_VERSION])))
......@@ -187,6 +187,6 @@ def clear_system_caches():
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
with LOGGER(True):
ncf_print(key=TAGS.RUN_START)
......@@ -48,7 +48,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
"must be a number.")
if eval_metric >= stop_threshold:
tf.logging.info(
tf.compat.v1.logging.info(
"Stop threshold of {} was passed with metric value {}.".format(
stop_threshold, eval_metric))
return True
......@@ -87,7 +87,7 @@ def generate_synthetic_data(
def apply_clean(flags_obj):
if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir):
tf.logging.info("--clean flag set. Removing existing model dir: {}".format(
if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format(
flags_obj.model_dir))
tf.gfile.DeleteRecursively(flags_obj.model_dir)
tf.io.gfile.rmtree(flags_obj.model_dir)
......@@ -69,13 +69,13 @@ class SyntheticDataTest(tf.test.TestCase):
"""Tests for generate_synthetic_data."""
def test_generate_synethetic_data(self):
input_element, label_element = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([5]),
input_value=123,
input_dtype=tf.float32,
label_shape=tf.TensorShape([]),
label_value=456,
label_dtype=tf.int32).make_one_shot_iterator().get_next()
input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
model_helpers.generate_synthetic_data(input_shape=tf.TensorShape([5]),
input_value=123,
input_dtype=tf.float32,
label_shape=tf.TensorShape([]),
label_value=456,
label_dtype=tf.int32)).get_next()
with self.test_session() as sess:
for n in range(5):
......@@ -89,7 +89,7 @@ class SyntheticDataTest(tf.test.TestCase):
input_value=43.5,
input_dtype=tf.float32)
element = d.make_one_shot_iterator().get_next()
element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertFalse(isinstance(element, tuple))
with self.test_session() as sess:
......@@ -102,7 +102,7 @@ class SyntheticDataTest(tf.test.TestCase):
'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
input_value=1.1)
element = d.make_one_shot_iterator().get_next()
element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertIn('a', element)
self.assertIn('b', element)
self.assertEquals(len(element['b']), 2)
......
......@@ -170,12 +170,12 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph")
with tf.gfile.Open(expected_file, "wb") as f:
with tf.io.gfile.GFile(expected_file, "wb") as f:
f.write(graph_bytes)
with graph.as_default():
init = tf.global_variables_initializer()
saver = tf.train.Saver()
init = tf.compat.v1.global_variables_initializer()
saver = tf.compat.v1.train.Saver()
with self.test_session(graph=graph) as sess:
sess.run(init)
......@@ -191,11 +191,11 @@ class BaseTest(tf.test.TestCase):
if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f:
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
json.dump(results, f)
with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f:
json.dump([tf.VERSION, tf.GIT_VERSION], f)
with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
"""Determine if a graph agrees with the reference data.
......@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph")
with tf.gfile.Open(expected_file, "rb") as f:
with tf.io.gfile.GFile(expected_file, "rb") as f:
expected_graph_bytes = f.read()
# The serialization is non-deterministic byte-for-byte. Instead there is
# a utility which evaluates the semantics of the two graphs to test for
......@@ -228,19 +228,19 @@ class BaseTest(tf.test.TestCase):
graph_bytes, expected_graph_bytes).decode("utf-8")
with graph.as_default():
init = tf.global_variables_initializer()
saver = tf.train.Saver()
init = tf.compat.v1.global_variables_initializer()
saver = tf.compat.v1.train.Saver()
with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f:
with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "r") as f:
tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence
tf_version_comparison = ""
if tf.GIT_VERSION != tf_git_version_reference:
if tf.version.GIT_VERSION != tf_git_version_reference:
tf_version_comparison = (
"Test was built using: {} (git = {})\n"
"Local TensorFlow version: {} (git = {})"
.format(tf_version_reference, tf_git_version_reference,
tf.VERSION, tf.GIT_VERSION)
tf.version.VERSION, tf.version.GIT_VERSION)
)
with self.test_session(graph=graph) as sess:
......@@ -249,7 +249,7 @@ class BaseTest(tf.test.TestCase):
saver.restore(sess=sess, save_path=os.path.join(
data_dir, self.ckpt_prefix))
if differences:
tf.logging.warn(
tf.compat.v1.logging.warn(
"The provided graph is different than expected:\n {}\n"
"However the weights were still able to be loaded.\n{}".format(
differences, tf_version_comparison)
......@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase):
eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f:
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
expected_results = json.load(f)
self.assertAllClose(results, expected_results)
......@@ -298,7 +298,7 @@ class BaseTest(tf.test.TestCase):
correctness_function=correctness_function
)
except:
tf.logging.error("Failed unittest {}".format(name))
tf.compat.v1.logging.error("Failed unittest {}".format(name))
raise
else:
self._construct_and_save_reference_files(
......
......@@ -63,12 +63,12 @@ class GoldenBaseTest(reference_data.BaseTest):
with g.as_default():
seed = self.name_to_seed(name)
seed = seed + 1 if bad_seed else seed
tf.set_random_seed(seed)
tf.compat.v1.set_random_seed(seed)
tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
tensor_shape = (1, 2) if wrong_shape else (1, 1)
input_tensor = tf.get_variable(
input_tensor = tf.compat.v1.get_variable(
tensor_name, dtype=tf.float32,
initializer=tf.random_uniform(tensor_shape, maxval=1)
initializer=tf.random.uniform(tensor_shape, maxval=1)
)
def correctness_function(tensor_result):
......@@ -86,13 +86,13 @@ class GoldenBaseTest(reference_data.BaseTest):
g = tf.Graph()
with g.as_default():
tf.set_random_seed(self.name_to_seed(name))
input_tensor = tf.get_variable(
tf.compat.v1.set_random_seed(self.name_to_seed(name))
input_tensor = tf.compat.v1.get_variable(
"input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((1, 2), maxval=1)
initializer=tf.random.uniform((1, 2), maxval=1)
)
layer = tf.layers.dense(inputs=input_tensor, units=4)
layer = tf.layers.dense(inputs=layer, units=1)
layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[layer], test=test,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment