Commit 3043566d authored by ayushmankumar7's avatar ayushmankumar7
Browse files

tf.compat.v1.logging implemented with absl

parent 1e2ceffd
...@@ -28,6 +28,7 @@ import unittest ...@@ -28,6 +28,7 @@ import unittest
import mock import mock
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from absl import logging
try: try:
from google.cloud import bigquery from google.cloud import bigquery
...@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase): ...@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
mock_logger = mock.MagicMock() mock_logger = mock.MagicMock()
mock_config_benchmark_logger.return_value = mock_logger mock_config_benchmark_logger.return_value = mock_logger
with logger.benchmark_context(None): with logger.benchmark_context(None):
tf.compat.v1.logging.info("start benchmarking") logging.info("start benchmarking")
mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS) mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
@mock.patch("official.utils.logs.logger.config_benchmark_logger") @mock.patch("official.utils.logs.logger.config_benchmark_logger")
...@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase): ...@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(BaseBenchmarkLoggerTest, self).setUp() super(BaseBenchmarkLoggerTest, self).setUp()
self._actual_log = tf.compat.v1.logging.info self._actual_log = logging.info
self.logged_message = None self.logged_message = None
def mock_log(*args, **kwargs): def mock_log(*args, **kwargs):
self.logged_message = args self.logged_message = args
self._actual_log(*args, **kwargs) self._actual_log(*args, **kwargs)
tf.compat.v1.logging.info = mock_log logging.info = mock_log
def tearDown(self): def tearDown(self):
super(BaseBenchmarkLoggerTest, self).tearDown() super(BaseBenchmarkLoggerTest, self).tearDown()
tf.compat.v1.logging.info = self._actual_log logging.info = self._actual_log
def test_log_metric(self): def test_log_metric(self):
log = logger.BaseBenchmarkLogger() log = logger.BaseBenchmarkLogger()
......
...@@ -31,6 +31,7 @@ import re ...@@ -31,6 +31,7 @@ import re
import subprocess import subprocess
import sys import sys
import typing import typing
from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -94,7 +95,7 @@ def get_mlperf_log(): ...@@ -94,7 +95,7 @@ def get_mlperf_log():
version = pkg_resources.get_distribution("mlperf_compliance") version = pkg_resources.get_distribution("mlperf_compliance")
version = tuple(int(i) for i in version.version.split(".")) version = tuple(int(i) for i in version.version.split("."))
if version < _MIN_VERSION: if version < _MIN_VERSION:
tf.compat.v1.logging.warning( logging.warning(
"mlperf_compliance is version {}, must be >= {}".format( "mlperf_compliance is version {}, must be >= {}".format(
".".join([str(i) for i in version]), ".".join([str(i) for i in version]),
".".join([str(i) for i in _MIN_VERSION]))) ".".join([str(i) for i in _MIN_VERSION])))
...@@ -187,6 +188,6 @@ def clear_system_caches(): ...@@ -187,6 +188,6 @@ def clear_system_caches():
if __name__ == "__main__": if __name__ == "__main__":
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
with LOGGER(True): with LOGGER(True):
ncf_print(key=TAGS.RUN_START) ncf_print(key=TAGS.RUN_START)
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import random import random
import string import string
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
from absl import logging
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
...@@ -252,7 +253,7 @@ class SyntheticIterator(object): ...@@ -252,7 +253,7 @@ class SyntheticIterator(object):
def _monkey_patch_dataset_method(strategy): def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method.""" """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset): def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.') logging.info('Using pure synthetic data.')
with self.scope(): with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync) return SyntheticDataset(dataset, self.num_replicas_in_sync)
......
...@@ -22,6 +22,7 @@ import numbers ...@@ -22,6 +22,7 @@ import numbers
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import nest from tensorflow.python.util import nest
from absl import logging
def past_stop_threshold(stop_threshold, eval_metric): def past_stop_threshold(stop_threshold, eval_metric):
...@@ -48,7 +49,7 @@ def past_stop_threshold(stop_threshold, eval_metric): ...@@ -48,7 +49,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
"must be a number.") "must be a number.")
if eval_metric >= stop_threshold: if eval_metric >= stop_threshold:
tf.compat.v1.logging.info( logging.info(
"Stop threshold of {} was passed with metric value {}.".format( "Stop threshold of {} was passed with metric value {}.".format(
stop_threshold, eval_metric)) stop_threshold, eval_metric))
return True return True
...@@ -88,6 +89,6 @@ def generate_synthetic_data( ...@@ -88,6 +89,6 @@ def generate_synthetic_data(
def apply_clean(flags_obj): def apply_clean(flags_obj):
if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir): if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:" logging.info("--clean flag set. Removing existing model dir:"
" {}".format(flags_obj.model_dir)) " {}".format(flags_obj.model_dir))
tf.io.gfile.rmtree(flags_obj.model_dir) tf.io.gfile.rmtree(flags_obj.model_dir)
...@@ -18,7 +18,7 @@ from __future__ import division ...@@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from absl import logging
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
...@@ -64,7 +64,7 @@ class PerfZeroBenchmark(tf.test.Benchmark): ...@@ -64,7 +64,7 @@ class PerfZeroBenchmark(tf.test.Benchmark):
def _setup(self): def _setup(self):
"""Sets up and resets flags before each test.""" """Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
if PerfZeroBenchmark.local_flags is None: if PerfZeroBenchmark.local_flags is None:
for flag_method in self.flag_methods: for flag_method in self.flag_methods:
flag_method() flag_method()
......
...@@ -21,6 +21,7 @@ from mock import Mock ...@@ -21,6 +21,7 @@ from mock import Mock
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from absl import logging
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import common from official.vision.image_classification import common
...@@ -106,5 +107,5 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -106,5 +107,5 @@ class KerasCommonTests(tf.test.TestCase):
return eval_output return eval_output
if __name__ == '__main__': if __name__ == '__main__':
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logging.set_verbosity(logging.ERROR)
googletest.main() googletest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment