"tests/vscode:/vscode.git/clone" did not exist on "64a5187d96f9376c7cf5123db810f2d2da79d7d0"
Commit d6b2b83c authored by Goldie Gadde's avatar Goldie Gadde Committed by Toby Boyd
Browse files

tf_upgrade_v2 on resnet and utils folders. (#6154)

* Add resnet56 short tests. (#6101)

* Add resnet56 short tests.
- created base benchmark module
- renamed accuracy test class to contain the word Accuracy
which will result in a need to update all the jobs
and a loss of history but is worth it.
- short tests are mostly copied from shining with oss refactor

* Address feedback.

* Move flag_methods to init
- Address setting default flags repeatedly.

* Rename accuracy tests.

* Lint errors resolved.

* fix model_dir set to flags.data_dir.

* fixed not fulling pulling out flag_methods.

* Use core mirrored strategy in official models (#6126)

* Imagenet short tests (#6132)

* Add short imagenet tests (taken from seemuch)
- also rename to match go forward naming

* fix method name

* Update doc strings.

* Fixe gpu number.

* points default data_dir to child folder. (#6131)

Failed test is python2  and was a kokoro failure

* Imagenet short tests (#6136)

* Add short imagenet tests (taken from seemuch)
- also rename to match go forward naming

* fix method name

* Update doc strings.

* Fixe gpu number.

* Add fill_objects

* fixed calling wrong class in super.

* fix lint issue.

* Flag (#6121)

* Fix the turn_off_ds flag problem

* add param names to all args

* Export benchmark stats using tf.test.Benchmark.report_benchmark() (#6103)

* Export benchmark stats using tf.test.Benchmark.report_benchmark()

* Fix python style using pyformat

* Typos. (#6120)

* log verbosity=2 logs every epoch no progress bars (#6142)

* tf_upgrade_v2 on resnet and utils folder.

* tf_upgrade_v2 on resnet and utils folder.
parent 722f345e
...@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names): ...@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
valid_flags = True valid_flags = True
for key in flag_names: for key in flag_names:
if not flag_values[key].startswith("gs://"): if not flag_values[key].startswith("gs://"):
tf.logging.error("{} must be a GCS path.".format(key)) tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
valid_flags = False valid_flags = False
return valid_flags return valid_flags
......
...@@ -25,7 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -25,7 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import logger from official.utils.logs import logger
class ExamplesPerSecondHook(tf.train.SessionRunHook): class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
"""Hook to print out examples per second. """Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps Total time is tracked and then divided by the total number of steps
...@@ -66,7 +66,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -66,7 +66,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
self._logger = metric_logger or logger.BaseBenchmarkLogger() self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.train.SecondOrStepTimer( self._timer = tf.estimator.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs) every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0 self._step_train_time = 0
...@@ -76,7 +76,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -76,7 +76,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
def begin(self): def begin(self):
"""Called once before using the session to check global step.""" """Called once before using the session to check global step."""
self._global_step_tensor = tf.train.get_global_step() self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None: if self._global_step_tensor is None:
raise RuntimeError( raise RuntimeError(
"Global step should be created to use StepCounterHook.") "Global step should be created to use StepCounterHook.")
...@@ -90,7 +90,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -90,7 +90,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
Returns: Returns:
A SessionRunArgs object or None if never triggered. A SessionRunArgs object or None if never triggered.
""" """
return tf.train.SessionRunArgs(self._global_step_tensor) return tf.estimator.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run(). """Called after each call to run().
......
...@@ -57,7 +57,7 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs): ...@@ -57,7 +57,7 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return [] return []
if use_tpu: if use_tpu:
tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is " tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
"specified. No hooks will be used.".format(name_list)) "specified. No hooks will be used.".format(name_list))
return [] return []
...@@ -89,7 +89,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): # ...@@ -89,7 +89,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): #
if tensors_to_log is None: if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG tensors_to_log = _TENSORS_TO_LOG
return tf.train.LoggingTensorHook( return tf.estimator.LoggingTensorHook(
tensors=tensors_to_log, tensors=tensors_to_log,
every_n_iter=every_n_iter) every_n_iter=every_n_iter)
...@@ -106,7 +106,7 @@ def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable= ...@@ -106,7 +106,7 @@ def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=
Returns a ProfilerHook that writes out timelines that can be loaded into Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing. profiling tools like chrome://tracing.
""" """
return tf.train.ProfilerHook(save_steps=save_steps, output_dir=model_dir) return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
def get_examples_per_second_hook(every_n_steps=100, def get_examples_per_second_hook(every_n_steps=100,
......
...@@ -45,7 +45,7 @@ class BaseTest(unittest.TestCase): ...@@ -45,7 +45,7 @@ class BaseTest(unittest.TestCase):
returned_hook = hooks_helper.get_train_hooks( returned_hook = hooks_helper.get_train_hooks(
[test_hook_name], model_dir="", **kwargs) [test_hook_name], model_dir="", **kwargs)
self.assertEqual(len(returned_hook), 1) self.assertEqual(len(returned_hook), 1)
self.assertIsInstance(returned_hook[0], tf.train.SessionRunHook) self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook)
self.assertEqual(returned_hook[0].__class__.__name__.lower(), self.assertEqual(returned_hook[0].__class__.__name__.lower(),
expected_hook_name) expected_hook_name)
......
...@@ -26,7 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -26,7 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import hooks from official.utils.logs import hooks
from official.utils.testing import mock_lib from official.utils.testing import mock_lib
tf.logging.set_verbosity(tf.logging.DEBUG) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
class ExamplesPerSecondHookTest(tf.test.TestCase): class ExamplesPerSecondHookTest(tf.test.TestCase):
...@@ -44,9 +44,10 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -44,9 +44,10 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
tf.train.create_global_step() tf.compat.v1.train.create_global_step()
self.train_op = tf.assign_add(tf.train.get_global_step(), 1) self.train_op = tf.compat.v1.assign_add(
self.global_step = tf.train.get_global_step() tf.compat.v1.train.get_global_step(), 1)
self.global_step = tf.compat.v1.train.get_global_step()
def test_raise_in_both_secs_and_steps(self): def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -71,8 +72,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -71,8 +72,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
warm_steps=warm_steps, warm_steps=warm_steps,
metric_logger=self._logger) metric_logger=self._logger)
with tf.train.MonitoredSession( with tf.compat.v1.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess: tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
for _ in range(every_n_steps): for _ in range(every_n_steps):
# Explicitly run global_step after train_op to get the accurate # Explicitly run global_step after train_op to get the accurate
# global_step value # global_step value
...@@ -125,8 +126,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -125,8 +126,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
every_n_secs=every_n_secs, every_n_secs=every_n_secs,
metric_logger=self._logger) metric_logger=self._logger)
with tf.train.MonitoredSession( with tf.compat.v1.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess: tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
# Explicitly run global_step after train_op to get the accurate # Explicitly run global_step after train_op to get the accurate
# global_step value # global_step value
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
......
...@@ -119,12 +119,13 @@ class BaseBenchmarkLogger(object): ...@@ -119,12 +119,13 @@ class BaseBenchmarkLogger(object):
eval_results: dict, the result of evaluate. eval_results: dict, the result of evaluate.
""" """
if not isinstance(eval_results, dict): if not isinstance(eval_results, dict):
tf.logging.warning("eval_results should be dictionary for logging. " tf.compat.v1.logging.warning(
"Got %s", type(eval_results)) "eval_results should be dictionary for logging. Got %s",
type(eval_results))
return return
global_step = eval_results[tf.GraphKeys.GLOBAL_STEP] global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
for key in sorted(eval_results): for key in sorted(eval_results):
if key != tf.GraphKeys.GLOBAL_STEP: if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
self.log_metric(key, eval_results[key], global_step=global_step) self.log_metric(key, eval_results[key], global_step=global_step)
def log_metric(self, name, value, unit=None, global_step=None, extras=None): def log_metric(self, name, value, unit=None, global_step=None, extras=None):
...@@ -143,12 +144,12 @@ class BaseBenchmarkLogger(object): ...@@ -143,12 +144,12 @@ class BaseBenchmarkLogger(object):
""" """
metric = _process_metric_to_json(name, value, unit, global_step, extras) metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric: if metric:
tf.logging.info("Benchmark metric: %s", metric) tf.compat.v1.logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None): def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
tf.logging.info("Benchmark run: %s", tf.compat.v1.logging.info(
_gather_run_info(model_name, dataset_name, run_params, "Benchmark run: %s", _gather_run_info(model_name, dataset_name,
test_id)) run_params, test_id))
def on_finish(self, status): def on_finish(self, status):
pass pass
...@@ -160,9 +161,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger): ...@@ -160,9 +161,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
def __init__(self, logging_dir): def __init__(self, logging_dir):
super(BenchmarkFileLogger, self).__init__() super(BenchmarkFileLogger, self).__init__()
self._logging_dir = logging_dir self._logging_dir = logging_dir
if not tf.gfile.IsDirectory(self._logging_dir): if not tf.io.gfile.isdir(self._logging_dir):
tf.gfile.MakeDirs(self._logging_dir) tf.io.gfile.makedirs(self._logging_dir)
self._metric_file_handler = tf.gfile.GFile( self._metric_file_handler = tf.io.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a") os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")
def log_metric(self, name, value, unit=None, global_step=None, extras=None): def log_metric(self, name, value, unit=None, global_step=None, extras=None):
...@@ -186,8 +187,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger): ...@@ -186,8 +187,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
self._metric_file_handler.write("\n") self._metric_file_handler.write("\n")
self._metric_file_handler.flush() self._metric_file_handler.flush()
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump metric to log file: " tf.compat.v1.logging.warning(
"name %s, value %s, error %s", name, value, e) "Failed to dump metric to log file: name %s, value %s, error %s",
name, value, e)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None): def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env. """Collect most of the TF runtime information for the local env.
...@@ -204,14 +206,14 @@ class BenchmarkFileLogger(BaseBenchmarkLogger): ...@@ -204,14 +206,14 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
""" """
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id) run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
with tf.gfile.GFile(os.path.join( with tf.io.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f: self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
try: try:
json.dump(run_info, f) json.dump(run_info, f)
f.write("\n") f.write("\n")
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump benchmark run info to log file: %s", tf.compat.v1.logging.warning(
e) "Failed to dump benchmark run info to log file: %s", e)
def on_finish(self, status): def on_finish(self, status):
self._metric_file_handler.flush() self._metric_file_handler.flush()
...@@ -324,7 +326,7 @@ def _process_metric_to_json( ...@@ -324,7 +326,7 @@ def _process_metric_to_json(
name, value, unit=None, global_step=None, extras=None): name, value, unit=None, global_step=None, extras=None):
"""Validate the metric data and generate JSON for insert.""" """Validate the metric data and generate JSON for insert."""
if not isinstance(value, numbers.Number): if not isinstance(value, numbers.Number):
tf.logging.warning( tf.compat.v1.logging.warning(
"Metric value to log should be a number. Got %s", type(value)) "Metric value to log should be a number. Got %s", type(value))
return None return None
...@@ -341,7 +343,7 @@ def _process_metric_to_json( ...@@ -341,7 +343,7 @@ def _process_metric_to_json(
def _collect_tensorflow_info(run_info): def _collect_tensorflow_info(run_info):
run_info["tensorflow_version"] = { run_info["tensorflow_version"] = {
"version": tf.VERSION, "git_hash": tf.GIT_VERSION} "version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}
def _collect_run_params(run_info, run_params): def _collect_run_params(run_info, run_params):
...@@ -385,7 +387,8 @@ def _collect_cpu_info(run_info): ...@@ -385,7 +387,8 @@ def _collect_cpu_info(run_info):
run_info["machine_config"]["cpu_info"] = cpu_info run_info["machine_config"]["cpu_info"] = cpu_info
except ImportError: except ImportError:
tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.") tf.compat.v1.logging.warn(
"'cpuinfo' not imported. CPU info will not be logged.")
def _collect_gpu_info(run_info, session_config=None): def _collect_gpu_info(run_info, session_config=None):
...@@ -415,7 +418,8 @@ def _collect_memory_info(run_info): ...@@ -415,7 +418,8 @@ def _collect_memory_info(run_info):
run_info["machine_config"]["memory_total"] = vmem.total run_info["machine_config"]["memory_total"] = vmem.total
run_info["machine_config"]["memory_available"] = vmem.available run_info["machine_config"]["memory_available"] = vmem.available
except ImportError: except ImportError:
tf.logging.warn("'psutil' not imported. Memory info will not be logged.") tf.compat.v1.logging.warn(
"'psutil' not imported. Memory info will not be logged.")
def _collect_test_environment(run_info): def _collect_test_environment(run_info):
......
...@@ -78,7 +78,7 @@ class BenchmarkLoggerTest(tf.test.TestCase): ...@@ -78,7 +78,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
mock_logger = mock.MagicMock() mock_logger = mock.MagicMock()
mock_config_benchmark_logger.return_value = mock_logger mock_config_benchmark_logger.return_value = mock_logger
with logger.benchmark_context(None): with logger.benchmark_context(None):
tf.logging.info("start benchmarking") tf.compat.v1.logging.info("start benchmarking")
mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS) mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
@mock.patch("official.utils.logs.logger.config_benchmark_logger") @mock.patch("official.utils.logs.logger.config_benchmark_logger")
...@@ -95,18 +95,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase): ...@@ -95,18 +95,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(BaseBenchmarkLoggerTest, self).setUp() super(BaseBenchmarkLoggerTest, self).setUp()
self._actual_log = tf.logging.info self._actual_log = tf.compat.v1.logging.info
self.logged_message = None self.logged_message = None
def mock_log(*args, **kwargs): def mock_log(*args, **kwargs):
self.logged_message = args self.logged_message = args
self._actual_log(*args, **kwargs) self._actual_log(*args, **kwargs)
tf.logging.info = mock_log tf.compat.v1.logging.info = mock_log
def tearDown(self): def tearDown(self):
super(BaseBenchmarkLoggerTest, self).tearDown() super(BaseBenchmarkLoggerTest, self).tearDown()
tf.logging.info = self._actual_log tf.compat.v1.logging.info = self._actual_log
def test_log_metric(self): def test_log_metric(self):
log = logger.BaseBenchmarkLogger() log = logger.BaseBenchmarkLogger()
...@@ -128,16 +128,16 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -128,16 +128,16 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
def tearDown(self): def tearDown(self):
super(BenchmarkFileLoggerTest, self).tearDown() super(BenchmarkFileLoggerTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
os.environ.clear() os.environ.clear()
os.environ.update(self.original_environ) os.environ.update(self.original_environ)
def test_create_logging_dir(self): def test_create_logging_dir(self):
non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir") non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir")
self.assertFalse(tf.gfile.IsDirectory(non_exist_temp_dir)) self.assertFalse(tf.io.gfile.isdir(non_exist_temp_dir))
logger.BenchmarkFileLogger(non_exist_temp_dir) logger.BenchmarkFileLogger(non_exist_temp_dir)
self.assertTrue(tf.gfile.IsDirectory(non_exist_temp_dir)) self.assertTrue(tf.io.gfile.isdir(non_exist_temp_dir))
def test_log_metric(self): def test_log_metric(self):
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
...@@ -145,8 +145,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -145,8 +145,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"}) log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
metric_log = os.path.join(log_dir, "metric.log") metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log)) self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.gfile.GFile(metric_log) as f: with tf.io.gfile.GFile(metric_log) as f:
metric = json.loads(f.readline()) metric = json.loads(f.readline())
self.assertEqual(metric["name"], "accuracy") self.assertEqual(metric["name"], "accuracy")
self.assertEqual(metric["value"], 0.999) self.assertEqual(metric["value"], 0.999)
...@@ -161,8 +161,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -161,8 +161,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("loss", 0.02, global_step=1e4) log.log_metric("loss", 0.02, global_step=1e4)
metric_log = os.path.join(log_dir, "metric.log") metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log)) self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.gfile.GFile(metric_log) as f: with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline()) accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy") self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.999) self.assertEqual(accuracy["value"], 0.999)
...@@ -184,7 +184,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -184,7 +184,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_metric("accuracy", const) log.log_metric("accuracy", const)
metric_log = os.path.join(log_dir, "metric.log") metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.gfile.Exists(metric_log)) self.assertFalse(tf.io.gfile.exists(metric_log))
def test_log_evaluation_result(self): def test_log_evaluation_result(self):
eval_result = {"loss": 0.46237424, eval_result = {"loss": 0.46237424,
...@@ -195,8 +195,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -195,8 +195,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_evaluation_result(eval_result) log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log") metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.gfile.Exists(metric_log)) self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.gfile.GFile(metric_log) as f: with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline()) accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy") self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.9285) self.assertEqual(accuracy["value"], 0.9285)
...@@ -216,7 +216,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -216,7 +216,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_evaluation_result(eval_result) log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log") metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.gfile.Exists(metric_log)) self.assertFalse(tf.io.gfile.exists(metric_log))
@mock.patch("official.utils.logs.logger._gather_run_info") @mock.patch("official.utils.logs.logger._gather_run_info")
def test_log_run_info(self, mock_gather_run_info): def test_log_run_info(self, mock_gather_run_info):
...@@ -229,8 +229,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -229,8 +229,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
log.log_run_info("model_name", "dataset_name", {}) log.log_run_info("model_name", "dataset_name", {})
run_log = os.path.join(log_dir, "benchmark_run.log") run_log = os.path.join(log_dir, "benchmark_run.log")
self.assertTrue(tf.gfile.Exists(run_log)) self.assertTrue(tf.io.gfile.exists(run_log))
with tf.gfile.GFile(run_log) as f: with tf.io.gfile.GFile(run_log) as f:
run_info = json.loads(f.readline()) run_info = json.loads(f.readline())
self.assertEqual(run_info["model_name"], "model_name") self.assertEqual(run_info["model_name"], "model_name")
self.assertEqual(run_info["dataset"], "dataset_name") self.assertEqual(run_info["dataset"], "dataset_name")
...@@ -240,8 +240,10 @@ class BenchmarkFileLoggerTest(tf.test.TestCase): ...@@ -240,8 +240,10 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
run_info = {} run_info = {}
logger._collect_tensorflow_info(run_info) logger._collect_tensorflow_info(run_info)
self.assertNotEqual(run_info["tensorflow_version"], {}) self.assertNotEqual(run_info["tensorflow_version"], {})
self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION) self.assertEqual(run_info["tensorflow_version"]["version"],
self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION) tf.version.VERSION)
self.assertEqual(run_info["tensorflow_version"]["git_hash"],
tf.version.GIT_VERSION)
def test_collect_run_params(self): def test_collect_run_params(self):
run_info = {} run_info = {}
...@@ -315,7 +317,7 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase): ...@@ -315,7 +317,7 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
def tearDown(self): def tearDown(self):
super(BenchmarkBigQueryLoggerTest, self).tearDown() super(BenchmarkBigQueryLoggerTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
os.environ.clear() os.environ.clear()
os.environ.update(self.original_environ) os.environ.update(self.original_environ)
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
class LoggingMetricHook(tf.train.LoggingTensorHook): class LoggingMetricHook(tf.estimator.LoggingTensorHook):
"""Hook to log benchmark metric information. """Hook to log benchmark metric information.
This hook is very similar as tf.train.LoggingTensorHook, which logs given This hook is very similar as tf.train.LoggingTensorHook, which logs given
...@@ -68,7 +68,7 @@ class LoggingMetricHook(tf.train.LoggingTensorHook): ...@@ -68,7 +68,7 @@ class LoggingMetricHook(tf.train.LoggingTensorHook):
def begin(self): def begin(self):
super(LoggingMetricHook, self).begin() super(LoggingMetricHook, self).begin()
self._global_step_tensor = tf.train.get_global_step() self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None: if self._global_step_tensor is None:
raise RuntimeError( raise RuntimeError(
"Global step should be created to use LoggingMetricHook.") "Global step should be created to use LoggingMetricHook.")
......
...@@ -39,7 +39,7 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -39,7 +39,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
def tearDown(self): def tearDown(self):
super(LoggingMetricHookTest, self).tearDown() super(LoggingMetricHookTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.io.gfile.rmtree(self.get_temp_dir())
def test_illegal_args(self): def test_illegal_args(self):
with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"): with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
...@@ -55,15 +55,15 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -55,15 +55,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5) metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
def test_print_at_end_only(self): def test_print_at_end_only(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
t = tf.constant(42.0, name="foo") t = tf.constant(42.0, name="foo")
train_op = tf.constant(3) train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger) tensors=[t.name], at_end=True, metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3): for _ in range(3):
mon_sess.run(train_op) mon_sess.run(train_op)
...@@ -88,8 +88,8 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -88,8 +88,8 @@ class LoggingMetricHookTest(tf.test.TestCase):
hook.begin() hook.begin()
def test_log_tensors(self): def test_log_tensors(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
t1 = tf.constant(42.0, name="foo") t1 = tf.constant(42.0, name="foo")
t2 = tf.constant(43.0, name="bar") t2 = tf.constant(43.0, name="bar")
train_op = tf.constant(3) train_op = tf.constant(3)
...@@ -97,7 +97,7 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -97,7 +97,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
tensors=[t1, t2], at_end=True, metric_logger=self._logger) tensors=[t1, t2], at_end=True, metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3): for _ in range(3):
mon_sess.run(train_op) mon_sess.run(train_op)
...@@ -126,7 +126,7 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -126,7 +126,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_logger=self._logger) metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op) mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name) self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
for _ in range(3): for _ in range(3):
...@@ -153,15 +153,15 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -153,15 +153,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1) self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_steps(self): def test_print_every_n_steps(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=False) self._validate_print_every_n_steps(sess, at_end=False)
# Verify proper reset. # Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=False) self._validate_print_every_n_steps(sess, at_end=False)
def test_print_every_n_steps_and_end(self): def test_print_every_n_steps_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=True) self._validate_print_every_n_steps(sess, at_end=True)
# Verify proper reset. # Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=True) self._validate_print_every_n_steps(sess, at_end=True)
...@@ -175,7 +175,7 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -175,7 +175,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
metric_logger=self._logger) metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op) mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name) self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
...@@ -199,15 +199,15 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -199,15 +199,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1) self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_secs(self): def test_print_every_n_secs(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=False) self._validate_print_every_n_secs(sess, at_end=False)
# Verify proper reset. # Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=False) self._validate_print_every_n_secs(sess, at_end=False)
def test_print_every_n_secs_and_end(self): def test_print_every_n_secs_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.train.get_or_create_global_step() tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=True) self._validate_print_every_n_secs(sess, at_end=True)
# Verify proper reset. # Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=True) self._validate_print_every_n_secs(sess, at_end=True)
......
...@@ -94,7 +94,7 @@ def get_mlperf_log(): ...@@ -94,7 +94,7 @@ def get_mlperf_log():
version = pkg_resources.get_distribution("mlperf_compliance") version = pkg_resources.get_distribution("mlperf_compliance")
version = tuple(int(i) for i in version.version.split(".")) version = tuple(int(i) for i in version.version.split("."))
if version < _MIN_VERSION: if version < _MIN_VERSION:
tf.logging.warning( tf.compat.v1.logging.warning(
"mlperf_compliance is version {}, must be >= {}".format( "mlperf_compliance is version {}, must be >= {}".format(
".".join([str(i) for i in version]), ".".join([str(i) for i in version]),
".".join([str(i) for i in _MIN_VERSION]))) ".".join([str(i) for i in _MIN_VERSION])))
...@@ -187,6 +187,6 @@ def clear_system_caches(): ...@@ -187,6 +187,6 @@ def clear_system_caches():
if __name__ == "__main__": if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
with LOGGER(True): with LOGGER(True):
ncf_print(key=TAGS.RUN_START) ncf_print(key=TAGS.RUN_START)
...@@ -48,7 +48,7 @@ def past_stop_threshold(stop_threshold, eval_metric): ...@@ -48,7 +48,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
"must be a number.") "must be a number.")
if eval_metric >= stop_threshold: if eval_metric >= stop_threshold:
tf.logging.info( tf.compat.v1.logging.info(
"Stop threshold of {} was passed with metric value {}.".format( "Stop threshold of {} was passed with metric value {}.".format(
stop_threshold, eval_metric)) stop_threshold, eval_metric))
return True return True
...@@ -87,7 +87,7 @@ def generate_synthetic_data( ...@@ -87,7 +87,7 @@ def generate_synthetic_data(
def apply_clean(flags_obj): def apply_clean(flags_obj):
if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir): if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
tf.logging.info("--clean flag set. Removing existing model dir: {}".format( tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format(
flags_obj.model_dir)) flags_obj.model_dir))
tf.gfile.DeleteRecursively(flags_obj.model_dir) tf.io.gfile.rmtree(flags_obj.model_dir)
...@@ -69,13 +69,13 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -69,13 +69,13 @@ class SyntheticDataTest(tf.test.TestCase):
"""Tests for generate_synthetic_data.""" """Tests for generate_synthetic_data."""
def test_generate_synethetic_data(self): def test_generate_synethetic_data(self):
input_element, label_element = model_helpers.generate_synthetic_data( input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
input_shape=tf.TensorShape([5]), model_helpers.generate_synthetic_data(input_shape=tf.TensorShape([5]),
input_value=123, input_value=123,
input_dtype=tf.float32, input_dtype=tf.float32,
label_shape=tf.TensorShape([]), label_shape=tf.TensorShape([]),
label_value=456, label_value=456,
label_dtype=tf.int32).make_one_shot_iterator().get_next() label_dtype=tf.int32)).get_next()
with self.test_session() as sess: with self.test_session() as sess:
for n in range(5): for n in range(5):
...@@ -89,7 +89,7 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -89,7 +89,7 @@ class SyntheticDataTest(tf.test.TestCase):
input_value=43.5, input_value=43.5,
input_dtype=tf.float32) input_dtype=tf.float32)
element = d.make_one_shot_iterator().get_next() element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertFalse(isinstance(element, tuple)) self.assertFalse(isinstance(element, tuple))
with self.test_session() as sess: with self.test_session() as sess:
...@@ -102,7 +102,7 @@ class SyntheticDataTest(tf.test.TestCase): ...@@ -102,7 +102,7 @@ class SyntheticDataTest(tf.test.TestCase):
'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}}, 'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
input_value=1.1) input_value=1.1)
element = d.make_one_shot_iterator().get_next() element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertIn('a', element) self.assertIn('a', element)
self.assertIn('b', element) self.assertIn('b', element)
self.assertEquals(len(element['b']), 2) self.assertEquals(len(element['b']), 2)
......
...@@ -170,12 +170,12 @@ class BaseTest(tf.test.TestCase): ...@@ -170,12 +170,12 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison. # Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString() graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph") expected_file = os.path.join(data_dir, "expected_graph")
with tf.gfile.Open(expected_file, "wb") as f: with tf.io.gfile.GFile(expected_file, "wb") as f:
f.write(graph_bytes) f.write(graph_bytes)
with graph.as_default(): with graph.as_default():
init = tf.global_variables_initializer() init = tf.compat.v1.global_variables_initializer()
saver = tf.train.Saver() saver = tf.compat.v1.train.Saver()
with self.test_session(graph=graph) as sess: with self.test_session(graph=graph) as sess:
sess.run(init) sess.run(init)
...@@ -191,11 +191,11 @@ class BaseTest(tf.test.TestCase): ...@@ -191,11 +191,11 @@ class BaseTest(tf.test.TestCase):
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f: with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
json.dump(results, f) json.dump(results, f)
with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f: with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
json.dump([tf.VERSION, tf.GIT_VERSION], f) json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function): def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
"""Determine if a graph agrees with the reference data. """Determine if a graph agrees with the reference data.
...@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase): ...@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison. # Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString() graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph") expected_file = os.path.join(data_dir, "expected_graph")
with tf.gfile.Open(expected_file, "rb") as f: with tf.io.gfile.GFile(expected_file, "rb") as f:
expected_graph_bytes = f.read() expected_graph_bytes = f.read()
# The serialization is non-deterministic byte-for-byte. Instead there is # The serialization is non-deterministic byte-for-byte. Instead there is
# a utility which evaluates the semantics of the two graphs to test for # a utility which evaluates the semantics of the two graphs to test for
...@@ -228,19 +228,19 @@ class BaseTest(tf.test.TestCase): ...@@ -228,19 +228,19 @@ class BaseTest(tf.test.TestCase):
graph_bytes, expected_graph_bytes).decode("utf-8") graph_bytes, expected_graph_bytes).decode("utf-8")
with graph.as_default(): with graph.as_default():
init = tf.global_variables_initializer() init = tf.compat.v1.global_variables_initializer()
saver = tf.train.Saver() saver = tf.compat.v1.train.Saver()
with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f: with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "r") as f:
tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence
tf_version_comparison = "" tf_version_comparison = ""
if tf.GIT_VERSION != tf_git_version_reference: if tf.version.GIT_VERSION != tf_git_version_reference:
tf_version_comparison = ( tf_version_comparison = (
"Test was built using: {} (git = {})\n" "Test was built using: {} (git = {})\n"
"Local TensorFlow version: {} (git = {})" "Local TensorFlow version: {} (git = {})"
.format(tf_version_reference, tf_git_version_reference, .format(tf_version_reference, tf_git_version_reference,
tf.VERSION, tf.GIT_VERSION) tf.version.VERSION, tf.version.GIT_VERSION)
) )
with self.test_session(graph=graph) as sess: with self.test_session(graph=graph) as sess:
...@@ -249,7 +249,7 @@ class BaseTest(tf.test.TestCase): ...@@ -249,7 +249,7 @@ class BaseTest(tf.test.TestCase):
saver.restore(sess=sess, save_path=os.path.join( saver.restore(sess=sess, save_path=os.path.join(
data_dir, self.ckpt_prefix)) data_dir, self.ckpt_prefix))
if differences: if differences:
tf.logging.warn( tf.compat.v1.logging.warn(
"The provided graph is different than expected:\n {}\n" "The provided graph is different than expected:\n {}\n"
"However the weights were still able to be loaded.\n{}".format( "However the weights were still able to be loaded.\n{}".format(
differences, tf_version_comparison) differences, tf_version_comparison)
...@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase): ...@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase):
eval_results = [op.eval() for op in ops_to_eval] eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f: with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
expected_results = json.load(f) expected_results = json.load(f)
self.assertAllClose(results, expected_results) self.assertAllClose(results, expected_results)
...@@ -298,7 +298,7 @@ class BaseTest(tf.test.TestCase): ...@@ -298,7 +298,7 @@ class BaseTest(tf.test.TestCase):
correctness_function=correctness_function correctness_function=correctness_function
) )
except: except:
tf.logging.error("Failed unittest {}".format(name)) tf.compat.v1.logging.error("Failed unittest {}".format(name))
raise raise
else: else:
self._construct_and_save_reference_files( self._construct_and_save_reference_files(
......
...@@ -63,12 +63,12 @@ class GoldenBaseTest(reference_data.BaseTest): ...@@ -63,12 +63,12 @@ class GoldenBaseTest(reference_data.BaseTest):
with g.as_default(): with g.as_default():
seed = self.name_to_seed(name) seed = self.name_to_seed(name)
seed = seed + 1 if bad_seed else seed seed = seed + 1 if bad_seed else seed
tf.set_random_seed(seed) tf.compat.v1.set_random_seed(seed)
tensor_name = "wrong_tensor" if wrong_name else "input_tensor" tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
tensor_shape = (1, 2) if wrong_shape else (1, 1) tensor_shape = (1, 2) if wrong_shape else (1, 1)
input_tensor = tf.get_variable( input_tensor = tf.compat.v1.get_variable(
tensor_name, dtype=tf.float32, tensor_name, dtype=tf.float32,
initializer=tf.random_uniform(tensor_shape, maxval=1) initializer=tf.random.uniform(tensor_shape, maxval=1)
) )
def correctness_function(tensor_result): def correctness_function(tensor_result):
...@@ -86,13 +86,13 @@ class GoldenBaseTest(reference_data.BaseTest): ...@@ -86,13 +86,13 @@ class GoldenBaseTest(reference_data.BaseTest):
g = tf.Graph() g = tf.Graph()
with g.as_default(): with g.as_default():
tf.set_random_seed(self.name_to_seed(name)) tf.compat.v1.set_random_seed(self.name_to_seed(name))
input_tensor = tf.get_variable( input_tensor = tf.compat.v1.get_variable(
"input_tensor", dtype=tf.float32, "input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((1, 2), maxval=1) initializer=tf.random.uniform((1, 2), maxval=1)
) )
layer = tf.layers.dense(inputs=input_tensor, units=4) layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
layer = tf.layers.dense(inputs=layer, units=1) layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
self._save_or_test_ops( self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[layer], test=test, name=name, graph=g, ops_to_eval=[layer], test=test,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment