Commit bd488858 authored by A. Unique TensorFlower

Merge pull request #8302 from ayushmankumar7:absl

PiperOrigin-RevId: 302043775
parents 2416dd9c 55bf4b80
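This PR swaps the deprecated tf.compat.v1.logging shim for absl's logging module throughout the benchmark and utility code. A minimal sketch of the migration pattern (the message text below is illustrative, not taken from the diff):

# Before: TF 1.x-style logging through the compat shim.
import tensorflow as tf

tf.compat.v1.logging.info("training step %d finished", 100)
tf.compat.v1.logging.warning("metric missing for step %d", 100)

# After: the same calls through absl, which TensorFlow uses internally anyway.
from absl import logging

logging.info("training step %d finished", 100)
logging.warning("metric missing for step %d", 100)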
@@ -28,6 +28,7 @@ import unittest
 import mock
 from absl.testing import flagsaver
 import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import logging

 try:
   from google.cloud import bigquery
@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
     mock_logger = mock.MagicMock()
     mock_config_benchmark_logger.return_value = mock_logger
     with logger.benchmark_context(None):
-      tf.compat.v1.logging.info("start benchmarking")
+      logging.info("start benchmarking")
     mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)

   @mock.patch("official.utils.logs.logger.config_benchmark_logger")
@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):

   def setUp(self):
     super(BaseBenchmarkLoggerTest, self).setUp()
-    self._actual_log = tf.compat.v1.logging.info
+    self._actual_log = logging.info
     self.logged_message = None

     def mock_log(*args, **kwargs):
       self.logged_message = args
       self._actual_log(*args, **kwargs)

-    tf.compat.v1.logging.info = mock_log
+    logging.info = mock_log

   def tearDown(self):
     super(BaseBenchmarkLoggerTest, self).tearDown()
-    tf.compat.v1.logging.info = self._actual_log
+    logging.info = self._actual_log

   def test_log_metric(self):
     log = logger.BaseBenchmarkLogger()
...
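The test above captures log calls by reassigning logging.info in setUp and restoring it in tearDown. For reference, mock.patch.object achieves the same capture and restores the original automatically; a small hedged sketch (the test body is illustrative, not part of this PR):

from unittest import mock

from absl import logging


def demo_capture_logging():
  # Replace absl's logging.info only inside the with-block; the original
  # function is restored when the context manager exits.
  with mock.patch.object(logging, "info") as mock_info:
    logging.info("start benchmarking")
  mock_info.assert_called_once_with("start benchmarking")


demo_capture_logging()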
@@ -31,8 +31,9 @@ import re
 import subprocess
 import sys
 import typing

-import tensorflow as tf
+from absl import logging
+# pylint:disable=logging-format-interpolation

 _MIN_VERSION = (0, 0, 10)
 _STACK_OFFSET = 2
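The new # pylint:disable=logging-format-interpolation silences pylint's complaint about passing pre-formatted str.format() strings to absl logging calls, which the warning in the next hunk does. absl logging also accepts lazy %-style arguments, in which case the disable is unnecessary; a hedged sketch of the two styles (values are made up):

from absl import logging

version = (0, 0, 9)
min_version = (0, 0, 10)

# Eager str.format interpolation, as in this file (triggers the pylint check).
logging.warning("mlperf_compliance is version {}, must be >= {}".format(
    ".".join(str(i) for i in version),
    ".".join(str(i) for i in min_version)))

# Lazy %-style interpolation; the message is only built if the record is emitted.
logging.warning("mlperf_compliance is version %s, must be >= %s",
                ".".join(str(i) for i in version),
                ".".join(str(i) for i in min_version))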
@@ -94,10 +95,9 @@ def get_mlperf_log():
     version = pkg_resources.get_distribution("mlperf_compliance")
     version = tuple(int(i) for i in version.version.split("."))
     if version < _MIN_VERSION:
-      tf.compat.v1.logging.warning(
-          "mlperf_compliance is version {}, must be >= {}".format(
-              ".".join([str(i) for i in version]),
-              ".".join([str(i) for i in _MIN_VERSION])))
+      logging.warning("mlperf_compliance is version {}, must be >= {}".format(
+          ".".join([str(i) for i in version]),
+          ".".join([str(i) for i in _MIN_VERSION])))
       raise ImportError
     return mlperf_compliance.mlperf_log
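The version guard above parses the installed mlperf_compliance version into an integer tuple and relies on lexicographic tuple comparison, so (0, 0, 9) < (0, 0, 10) even though "9" > "10" as strings. A standalone sketch of the same check against an arbitrary installed package (the package name is a placeholder):

import pkg_resources

_MIN_VERSION = (0, 0, 10)


def meets_min_version(package_name):
  """Returns True if the installed package is at least _MIN_VERSION."""
  dist = pkg_resources.get_distribution(package_name)
  version = tuple(int(i) for i in dist.version.split("."))
  return version >= _MIN_VERSION


# Example: (0, 0, 9) compares below (0, 0, 10) as tuples of ints.
print((0, 0, 9) < (0, 0, 10))  # -> True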
@@ -187,6 +187,6 @@ def clear_system_caches():
 if __name__ == "__main__":
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   with LOGGER(True):
     ncf_print(key=TAGS.RUN_START)
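In the __main__ block the verbosity call moves to absl as well; the level constants live directly on the logging module. A minimal sketch of the filtering behavior:

from absl import logging

logging.set_verbosity(logging.INFO)
logging.info("emitted: INFO is at the current threshold")
logging.debug("suppressed: DEBUG is below the INFO threshold")

logging.set_verbosity(logging.DEBUG)
logging.debug("emitted now that the threshold is DEBUG")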
@@ -22,6 +22,8 @@ import json
 import os
 import random
 import string

+from absl import logging
 import tensorflow.compat.v2 as tf

 from official.utils.misc import tpu_lib
@@ -252,7 +254,7 @@ class SyntheticIterator(object):
 def _monkey_patch_dataset_method(strategy):
   """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
   def make_dataset(self, dataset):
-    tf.compat.v1.logging.info('Using pure synthetic data.')
+    logging.info('Using pure synthetic data.')
     with self.scope():
       if self.extended._global_batch_size:  # pylint: disable=protected-access
         return SyntheticDataset(dataset, self.num_replicas_in_sync)
...
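The hunk above replaces the strategy's make_dataset_iterator at runtime so that it logs and returns a synthetic dataset instead of touching real input. A generic, hypothetical sketch of that monkey-patching pattern (the Strategy class and the synthetic stand-in below are illustrative, not the real tf.distribute API):

from absl import logging


class Strategy(object):
  """Stand-in for a distribution strategy with a dataset-iterator method."""

  def make_dataset_iterator(self, dataset):
    return iter(dataset)


def monkey_patch_dataset_method(strategy_cls):
  """Swaps make_dataset_iterator for a synthetic version; returns the original."""
  original = strategy_cls.make_dataset_iterator

  def make_synthetic(self, dataset):
    logging.info("Using pure synthetic data.")
    return iter([0, 0, 0, 0])  # trivial synthetic stand-in

  strategy_cls.make_dataset_iterator = make_synthetic
  return original


original = monkey_patch_dataset_method(Strategy)
print(next(Strategy().make_dataset_iterator(range(10))))  # -> 0
Strategy.make_dataset_iterator = original  # restore the real method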
@@ -20,8 +20,11 @@ from __future__ import print_function
 import numbers

+from absl import logging
 import tensorflow as tf
 from tensorflow.python.util import nest

+# pylint:disable=logging-format-interpolation


 def past_stop_threshold(stop_threshold, eval_metric):
@@ -48,9 +51,8 @@ def past_stop_threshold(stop_threshold, eval_metric):
                      "must be a number.")

   if eval_metric >= stop_threshold:
-    tf.compat.v1.logging.info(
-        "Stop threshold of {} was passed with metric value {}.".format(
-            stop_threshold, eval_metric))
+    logging.info("Stop threshold of {} was passed with metric value {}.".format(
+        stop_threshold, eval_metric))
     return True

   return False
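past_stop_threshold lets a training driver stop early once an evaluation metric reaches the configured threshold. A simplified, self-contained re-implementation for illustration (the real function also validates that both values are numbers, per the hunk above):

def past_stop_threshold(stop_threshold, eval_metric):
  # None means early stopping is disabled.
  if stop_threshold is None:
    return False
  return eval_metric >= stop_threshold


# Illustrative early-stopping loop with made-up accuracy values.
for epoch, accuracy in enumerate([0.71, 0.84, 0.91, 0.95]):
  if past_stop_threshold(0.90, accuracy):
    print("stopping after epoch", epoch)  # -> stopping after epoch 2
    break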
@@ -88,6 +90,6 @@ def generate_synthetic_data(
 def apply_clean(flags_obj):
   if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
-    tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:"
-                              " {}".format(flags_obj.model_dir))
+    logging.info("--clean flag set. Removing existing model dir:"
+                 " {}".format(flags_obj.model_dir))
     tf.io.gfile.rmtree(flags_obj.model_dir)
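apply_clean deletes an existing model directory through tf.io.gfile, which handles both local paths and remote filesystems such as GCS with the same calls. A minimal sketch on a local temporary directory:

import os
import tempfile

import tensorflow as tf

model_dir = os.path.join(tempfile.mkdtemp(), "model")
tf.io.gfile.makedirs(model_dir)

if tf.io.gfile.exists(model_dir):
  tf.io.gfile.rmtree(model_dir)  # recursive delete, as in apply_clean

print(tf.io.gfile.exists(model_dir))  # -> False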
@@ -20,8 +20,9 @@ from __future__ import print_function
 import os

 from absl import flags
+from absl import logging
 from absl.testing import flagsaver
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf

 FLAGS = flags.FLAGS
@@ -75,7 +76,7 @@ class PerfZeroBenchmark(tf.test.Benchmark):

   def _setup(self):
     """Sets up and resets flags before each test."""
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+    logging.set_verbosity(logging.INFO)
     if PerfZeroBenchmark.local_flags is None:
       for flag_method in self.flag_methods:
         flag_method()
...
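PerfZeroBenchmark resets absl flags before each benchmark method; the flagsaver helper imported above offers a decorator that saves and restores flag values around a call. A small hedged sketch with a made-up flag:

from absl import flags
from absl.testing import flagsaver

flags.DEFINE_integer("batch_size", 32, "Illustrative flag, not from this PR.")
FLAGS = flags.FLAGS
FLAGS.mark_as_parsed()  # allow flag access without parsing argv


@flagsaver.flagsaver(batch_size=128)
def run_with_large_batch():
  return FLAGS.batch_size


print(run_with_large_batch())  # -> 128
print(FLAGS.batch_size)        # -> 32, restored after the decorated call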