Commit 3043566d authored by ayushmankumar7's avatar ayushmankumar7
Browse files

tf.compat.v1.logging implemented with absl

parent 1e2ceffd
......@@ -28,6 +28,7 @@ import unittest
import mock
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
from absl import logging
try:
from google.cloud import bigquery
......@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
mock_logger = mock.MagicMock()
mock_config_benchmark_logger.return_value = mock_logger
with logger.benchmark_context(None):
tf.compat.v1.logging.info("start benchmarking")
logging.info("start benchmarking")
mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
@mock.patch("official.utils.logs.logger.config_benchmark_logger")
......@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
def setUp(self):
  """Swaps a capturing stub in for absl ``logging.info`` before each test.

  Saves the real ``logging.info`` so ``tearDown`` can restore it, then
  replaces it with a wrapper that records the call args in
  ``self.logged_message`` while still delegating to the real logger.
  """
  super(BaseBenchmarkLoggerTest, self).setUp()
  # Keep a handle to the real absl logging.info for restoration in tearDown.
  self._actual_log = logging.info
  self.logged_message = None

  def mock_log(*args, **kwargs):
    # Record the positional args for test assertions, then log for real.
    self.logged_message = args
    self._actual_log(*args, **kwargs)

  logging.info = mock_log
def tearDown(self):
  """Restores the real absl ``logging.info`` that setUp patched out."""
  super(BaseBenchmarkLoggerTest, self).tearDown()
  # Undo the monkey-patch so later tests see the genuine logger.
  logging.info = self._actual_log
def test_log_metric(self):
log = logger.BaseBenchmarkLogger()
......
......@@ -31,6 +31,7 @@ import re
import subprocess
import sys
import typing
from absl import logging
import tensorflow as tf
......@@ -94,7 +95,7 @@ def get_mlperf_log():
version = pkg_resources.get_distribution("mlperf_compliance")
version = tuple(int(i) for i in version.version.split("."))
if version < _MIN_VERSION:
tf.compat.v1.logging.warning(
logging.warning(
"mlperf_compliance is version {}, must be >= {}".format(
".".join([str(i) for i in version]),
".".join([str(i) for i in _MIN_VERSION])))
......@@ -187,6 +188,6 @@ def clear_system_caches():
if __name__ == "__main__":
  # Use absl logging directly (this commit migrates off tf.compat.v1.logging).
  logging.set_verbosity(logging.INFO)
  with LOGGER(True):
    ncf_print(key=TAGS.RUN_START)
......@@ -23,6 +23,7 @@ import os
import random
import string
import tensorflow.compat.v2 as tf
from absl import logging
from official.utils.misc import tpu_lib
......@@ -252,7 +253,7 @@ class SyntheticIterator(object):
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.')
logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync)
......
......@@ -22,6 +22,7 @@ import numbers
import tensorflow as tf
from tensorflow.python.util import nest
from absl import logging
def past_stop_threshold(stop_threshold, eval_metric):
......@@ -48,7 +49,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
"must be a number.")
if eval_metric >= stop_threshold:
tf.compat.v1.logging.info(
logging.info(
"Stop threshold of {} was passed with metric value {}.".format(
stop_threshold, eval_metric))
return True
......@@ -88,6 +89,6 @@ def generate_synthetic_data(
def apply_clean(flags_obj):
  """Deletes the existing model directory when the --clean flag is set.

  Args:
    flags_obj: Parsed-flags object; must expose ``clean`` (bool) and
      ``model_dir`` (str path). NOTE(review): assumed from usage here —
      confirm against the flag definitions elsewhere in the project.
  """
  if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
    logging.info("--clean flag set. Removing existing model dir:"
                 " {}".format(flags_obj.model_dir))
    tf.io.gfile.rmtree(flags_obj.model_dir)
......@@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
import os
from absl import logging
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
......@@ -64,7 +64,7 @@ class PerfZeroBenchmark(tf.test.Benchmark):
def _setup(self):
"""Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
logging.set_verbosity(logging.INFO)
if PerfZeroBenchmark.local_flags is None:
for flag_method in self.flag_methods:
flag_method()
......
......@@ -21,6 +21,7 @@ from mock import Mock
import numpy as np
import tensorflow as tf
from absl import logging
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.vision.image_classification import common
......@@ -106,5 +107,5 @@ class KerasCommonTests(tf.test.TestCase):
return eval_output
if __name__ == '__main__':
  # Keep test output quiet: surface only error-level log records.
  logging.set_verbosity(logging.ERROR)
  googletest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment