Unverified Commit 0270cac7 authored by Qianli Scott Zhu's avatar Qianli Scott Zhu Committed by GitHub

Add benchmark logger that does stream upload to bigquery. (#4210)

* Move the benchmark_uploader to new location.

* Update benchmark logger to streaming upload.

* Fix lint and unit test error.

* delint.

* Update the benchmark uploader test.

Skip the import of benchmark_uploader when bigquery is not installed.

* Merge the 2 classes of benchmark uploader into 1.

* Address review comments.

* delint.

* Execute bigquery upload in a separate thread.

* Change to use python six.moves for importing.

* Address review comments and delint.

* Address review comment.

Add a comment about the potential performance impact for models that run on CPU.

* Fix random failure on py3.

* Fix the order of the flag savers to avoid randomness.

The test breaks when benchmark_logger_type is set first and its validator
runs while benchmark_log_dir is not yet set.
parent 80178fc6
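At a glance, the change lets a model script pick the streaming BigQuery logger purely through flags. A minimal usage sketch (not part of this commit; the model name, dataset name, and metric values below are placeholders):

from absl import app as absl_app
from absl import flags

from official.utils.flags import core as flags_core
from official.utils.logs import logger


def main(_):
  # Picks BaseBenchmarkLogger, BenchmarkFileLogger or BenchmarkBigQueryLogger
  # based on --benchmark_logger_type (and its companion flags).
  benchmark_logger = logger.config_benchmark_logger(flags.FLAGS)
  benchmark_logger.log_run_info("example_model", "example_dataset",
                                {"batch_size": 32})
  benchmark_logger.log_metric("accuracy", 0.99, global_step=1000)


if __name__ == "__main__":
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)
  absl_app.run(main=main)

Running it with --benchmark_logger_type=BenchmarkBigQueryLogger plus the --gcp_project, --bigquery_data_set, --bigquery_run_table and --bigquery_metric_table flags defined in this change streams the rows to BigQuery.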
......@@ -25,30 +25,19 @@ from __future__ import division
from __future__ import print_function
import json
import os
import sys
import uuid
from google.cloud import bigquery
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import logger
class BigQueryUploader(object):
"""Upload the benchmark and metric info to BigQuery."""
"""Upload the benchmark and metric info from JSON input to BigQuery. """
def __init__(self, logging_dir, gcp_project=None, credentials=None):
def __init__(self, gcp_project=None, credentials=None):
"""Initialized BigQueryUploader with proper setting.
Args:
logging_dir: string, logging directory that contains the benchmark log.
gcp_project: string, the name of the GCP project that the log will be
uploaded to. The default project name will be detected from local
environment if no value is provided.
......@@ -58,11 +47,11 @@ class BigQueryUploader(object):
google.oauth2.service_account.Credentials to load credentials from a local
file for the case where the test is run outside of GCP.
"""
self._logging_dir = logging_dir
self._bq_client = bigquery.Client(
project=gcp_project, credentials=credentials)
def upload_benchmark_run(self, dataset_name, table_name, run_id):
def upload_benchmark_run_json(
self, dataset_name, table_name, run_id, run_json):
"""Upload benchmark run information to Bigquery.
Args:
......@@ -72,19 +61,13 @@ class BigQueryUploader(object):
the data will be uploaded.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format.
run_json: dict, the JSON data that contains the benchmark run info.
"""
expected_file = os.path.join(
self._logging_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
with tf.gfile.GFile(expected_file) as f:
benchmark_json = json.load(f)
benchmark_json["model_id"] = run_id
table_ref = self._bq_client.dataset(dataset_name).table(table_name)
errors = self._bq_client.insert_rows_json(table_ref, [benchmark_json])
if errors:
tf.logging.error(
"Failed to upload benchmark info to bigquery: {}".format(errors))
def upload_metric(self, dataset_name, table_name, run_id):
run_json["model_id"] = run_id
self._upload_json(dataset_name, table_name, [run_json])
def upload_benchmark_metric_json(
self, dataset_name, table_name, run_id, metric_json_list):
"""Upload metric information to Bigquery.
Args:
......@@ -95,39 +78,57 @@ class BigQueryUploader(object):
benchmark_run table.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format. This should be the same as the benchmark run_id.
metric_json_list: list, a list of JSON objects that record the metric info.
"""
for m in metric_json_list:
m["run_id"] = run_id
self._upload_json(dataset_name, table_name, metric_json_list)
def upload_benchmark_run_file(
self, dataset_name, table_name, run_id, run_json_file):
"""Upload benchmark run information to Bigquery from input json file.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the data will be uploaded.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format.
run_json_file: string, the file path that contains the run JSON data.
"""
with tf.gfile.GFile(run_json_file) as f:
benchmark_json = json.load(f)
self.upload_benchmark_run_json(
dataset_name, table_name, run_id, benchmark_json)
def upload_metric_file(
self, dataset_name, table_name, run_id, metric_json_file):
"""Upload metric information to Bigquery from input json file.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the metric data will be uploaded. This is different from the
benchmark_run table.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format. This should be the same as the benchmark run_id.
metric_json_file: string, the file path that contains the metric JSON
data.
"""
expected_file = os.path.join(
self._logging_dir, logger.METRIC_LOG_FILE_NAME)
with tf.gfile.GFile(expected_file) as f:
lines = f.readlines()
with tf.gfile.GFile(metric_json_file) as f:
metrics = []
for line in filter(lambda l: l.strip(), lines):
metric = json.loads(line)
metric["run_id"] = run_id
metrics.append(metric)
table_ref = self._bq_client.dataset(dataset_name).table(table_name)
errors = self._bq_client.insert_rows_json(table_ref, metrics)
if errors:
tf.logging.error(
"Failed to upload benchmark info to bigquery: {}".format(errors))
def main(_):
if not flags.FLAGS.benchmark_log_dir:
print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
sys.exit(1)
uploader = BigQueryUploader(
flags.FLAGS.benchmark_log_dir,
gcp_project=flags.FLAGS.gcp_project)
run_id = str(uuid.uuid4())
uploader.upload_benchmark_run(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id)
uploader.upload_metric(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id)
if __name__ == "__main__":
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
absl_app.run(main=main)
for line in f:
metrics.append(json.loads(line.strip()))
self.upload_benchmark_metric_json(
dataset_name, table_name, run_id, metrics)
def _upload_json(self, dataset_name, table_name, json_list):
# Find the unique table reference based on the dataset and table name, so that
# the data can be inserted into it.
table_ref = self._bq_client.dataset(dataset_name).table(table_name)
errors = self._bq_client.insert_rows_json(table_ref, json_list)
if errors:
tf.logging.error(
"Failed to upload benchmark info to bigquery: {}".format(errors))
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Binary to upload benchmark generated by BenchmarkLogger to remote repo.
This library require google cloud bigquery lib as dependency, which can be
installed with:
> pip install --upgrade google-cloud-bigquery
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import uuid
from absl import app as absl_app
from absl import flags
from official.benchmark import benchmark_uploader
from official.utils.flags import core as flags_core
from official.utils.logs import logger
def main(_):
if not flags.FLAGS.benchmark_log_dir:
print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
sys.exit(1)
uploader = benchmark_uploader.BigQueryUploader(
gcp_project=flags.FLAGS.gcp_project)
run_id = str(uuid.uuid4())
run_json_file = os.path.join(
flags.FLAGS.benchmark_log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
metric_json_file = os.path.join(
flags.FLAGS.benchmark_log_dir, logger.METRIC_LOG_FILE_NAME)
uploader.upload_benchmark_run_file(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id,
run_json_file)
uploader.upload_metric_file(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
metric_json_file)
if __name__ == "__main__":
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
absl_app.run(main=main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for benchmark_uploader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import unittest
from mock import MagicMock
from mock import patch
import tensorflow as tf # pylint: disable=g-bad-import-order
try:
from google.cloud import bigquery
from official.benchmark import benchmark_uploader
except ImportError:
bigquery = None
benchmark_uploader = None
@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
class BigQueryUploaderTest(tf.test.TestCase):
@patch.object(bigquery, 'Client')
def setUp(self, mock_bigquery):
self.mock_client = mock_bigquery.return_value
self.mock_dataset = MagicMock(name="dataset")
self.mock_table = MagicMock(name="table")
self.mock_client.dataset.return_value = self.mock_dataset
self.mock_dataset.table.return_value = self.mock_table
self.mock_client.insert_rows_json.return_value = []
self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
self.benchmark_uploader._bq_client = self.mock_client
self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
with open(os.path.join(self.log_dir, 'metric.log'), 'a') as f:
json.dump({'name': 'accuracy', 'value': 1.0}, f)
f.write("\n")
json.dump({'name': 'loss', 'value': 0.5}, f)
f.write("\n")
with open(os.path.join(self.log_dir, 'run.log'), 'w') as f:
json.dump({'model_name': 'value'}, f)
def tearDown(self):
tf.gfile.DeleteRecursively(self.get_temp_dir())
def test_upload_benchmark_run_json(self):
self.benchmark_uploader.upload_benchmark_run_json(
'dataset', 'table', 'run_id', {'model_name': 'value'})
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
def test_upload_benchmark_metric_json(self):
metric_json_list = [
{'name': 'accuracy', 'value': 1.0},
{'name': 'loss', 'value': 0.5}
]
expected_params = [
{'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
{'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
]
self.benchmark_uploader.upload_benchmark_metric_json(
'dataset', 'table', 'run_id', metric_json_list)
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
def test_upload_benchmark_run_file(self):
self.benchmark_uploader.upload_benchmark_run_file(
'dataset', 'table', 'run_id', os.path.join(self.log_dir, 'run.log'))
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
def test_upload_metric_file(self):
self.benchmark_uploader.upload_metric_file(
'dataset', 'table', 'run_id',
os.path.join(self.log_dir, 'metric.log'))
expected_params = [
{'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
{'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
]
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
if __name__ == '__main__':
tf.test.main()
......@@ -395,13 +395,12 @@ def resnet_main(
'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs,
}
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger = logger.config_benchmark_logger(flags_obj)
benchmark_logger.log_run_info('resnet', dataset_name, run_params)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
batch_size=flags_obj.batch_size,
benchmark_log_dir=flags_obj.benchmark_log_dir)
batch_size=flags_obj.batch_size)
def input_fn_train():
return input_function(
......
......@@ -36,6 +36,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
key_flags = []
flags.DEFINE_enum(
name="benchmark_logger_type", default="BaseBenchmarkLogger",
enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger",
"BenchmarkBigQueryLogger"],
help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work."))
if benchmark_log_dir:
flags.DEFINE_string(
name="benchmark_log_dir", short_name="bld", default=None,
......@@ -64,4 +72,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
help=help_wrap("The Bigquery table name where the benchmark metric "
"information will be uploaded."))
return key_flags
@flags.multi_flags_validator(
["benchmark_logger_type", "benchmark_log_dir"],
message="--benchmark_logger_type=BenchmarkFileLogger will require "
"--benchmark_log_dir being set")
def _check_benchmark_log_dir(flags_dict):
benchmark_logger_type = flags_dict["benchmark_logger_type"]
if benchmark_logger_type == "BenchmarkFileLogger":
return flags_dict["benchmark_log_dir"]
return True
return key_flags
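To make the new validator's effect concrete, a small sketch (an assumption here: absl surfaces a failed multi-flag validator as flags.IllegalFlagValueError at flag-parsing time):

from absl import flags
from official.utils.flags import core as flags_core

flags_core.define_benchmark()
try:
  # BenchmarkFileLogger without --benchmark_log_dir trips the validator above.
  flags.FLAGS(["prog", "--benchmark_logger_type=BenchmarkFileLogger"])
except flags.IllegalFlagValueError as e:
  print("Flag validation failed as expected: %s" % e)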
......@@ -124,14 +124,12 @@ def get_examples_per_second_hook(every_n_steps=100,
warm_steps=warm_steps)
def get_logging_metric_hook(benchmark_log_dir=None,
tensors_to_log=None,
def get_logging_metric_hook(tensors_to_log=None,
every_n_secs=600,
**kwargs): # pylint: disable=unused-argument
"""Function to get LoggingMetricHook.
Args:
benchmark_log_dir: `string`, directory path to save the metric log.
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every
......@@ -141,7 +139,6 @@ def get_logging_metric_hook(benchmark_log_dir=None,
Returns a LoggingMetricHook that logs the given tensors as benchmark
metrics.
"""
logger.config_benchmark_logger(benchmark_log_dir)
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return metric_hook.LoggingMetricHook(
......
......@@ -60,8 +60,7 @@ class BaseTest(unittest.TestCase):
def test_get_logging_metric_hook(self):
test_hook_name = 'LoggingMetricHook'
self.validate_train_hook_name(test_hook_name, 'loggingmetrichook',
benchmark_log_dir='/tmp')
self.validate_train_hook_name(test_hook_name, 'loggingmetrichook')
if __name__ == '__main__':
tf.test.main()
......@@ -28,7 +28,10 @@ import multiprocessing
import numbers
import os
import threading
import uuid
from six.moves import _thread as thread
from absl import flags
import tensorflow as tf
from tensorflow.python.client import device_lib
......@@ -36,21 +39,39 @@ METRIC_LOG_FILE_NAME = "metric.log"
BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
FLAGS = flags.FLAGS
# Don't use it directly. Use get_benchmark_logger to access a logger.
_benchmark_logger = None
_logger_lock = threading.Lock()
def config_benchmark_logger(logging_dir):
def config_benchmark_logger(flag_obj=None):
"""Config the global benchmark logger"""
_logger_lock.acquire()
try:
global _benchmark_logger
if logging_dir:
_benchmark_logger = BenchmarkFileLogger(logging_dir)
else:
if not flag_obj:
flag_obj = FLAGS
if (not hasattr(flag_obj, 'benchmark_logger_type') or
flag_obj.benchmark_logger_type == 'BaseBenchmarkLogger'):
_benchmark_logger = BaseBenchmarkLogger()
elif flag_obj.benchmark_logger_type == 'BenchmarkFileLogger':
_benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
elif flag_obj.benchmark_logger_type == 'BenchmarkBigQueryLogger':
from official.benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top
bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
_benchmark_logger = BenchmarkBigQueryLogger(
bigquery_uploader=bq_uploader,
bigquery_data_set=flag_obj.bigquery_data_set,
bigquery_run_table=flag_obj.bigquery_run_table,
bigquery_metric_table=flag_obj.bigquery_metric_table,
run_id=str(uuid.uuid4()))
else:
raise ValueError('Unrecognized benchmark_logger_type: %s' %
flag_obj.benchmark_logger_type)
finally:
_logger_lock.release()
return _benchmark_logger
......@@ -58,8 +79,7 @@ def config_benchmark_logger(logging_dir):
def get_benchmark_logger():
if not _benchmark_logger:
config_benchmark_logger(None)
config_benchmark_logger()
return _benchmark_logger
......@@ -99,15 +119,9 @@ class BaseBenchmarkLogger(object):
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
if not isinstance(value, numbers.Number):
tf.logging.warning(
"Metric value to log should be a number. Got %s", type(value))
return
extras = _convert_to_json_dict(extras)
tf.logging.info("Benchmark metric: "
"Name %s, value %d, unit %s, global_step %d, extras %s",
name, value, unit, global_step, extras)
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
tf.logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params):
tf.logging.info("Benchmark run: %s",
......@@ -137,28 +151,16 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
if not isinstance(value, numbers.Number):
tf.logging.warning(
"Metric value to log should be a number. Got %s", type(value))
return
extras = _convert_to_json_dict(extras)
with tf.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a") as f:
metric = {
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"timestamp": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN),
"extras": extras}
try:
json.dump(metric, f)
f.write("\n")
except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump metric to log file: "
"name %s, value %s, error %s", name, value, e)
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
with tf.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a") as f:
try:
json.dump(metric, f)
f.write("\n")
except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump metric to log file: "
"name %s, value %s, error %s", name, value, e)
def log_run_info(self, model_name, dataset_name, run_params):
"""Collect most of the TF runtime information for the local env.
......@@ -183,6 +185,68 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
e)
class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
"""Class to log the benchmark information to BigQuery data store."""
def __init__(self,
bigquery_uploader,
bigquery_data_set,
bigquery_run_table,
bigquery_metric_table,
run_id):
super(BenchmarkBigQueryLogger, self).__init__()
self._bigquery_uploader = bigquery_uploader
self._bigquery_data_set = bigquery_data_set
self._bigquery_run_table = bigquery_run_table
self._bigquery_metric_table = bigquery_metric_table
self._run_id = run_id
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to bigquery.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, E.g "image per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
# Start a new thread for the BigQuery upload, since it might take a long
# time and impact the benchmark and performance measurement. Starting a new
# thread might itself have a performance impact for models that run on CPU.
thread.start_new_thread(
self._bigquery_uploader.upload_benchmark_metric_json,
(self._bigquery_data_set,
self._bigquery_metric_table,
self._run_id,
[metric]))
def log_run_info(self, model_name, dataset_name, run_params):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params)
# Start a new thread for the BigQuery upload, since it might take a long
# time and impact the benchmark and performance measurement. Starting a new
# thread might itself have a performance impact for models that run on CPU.
thread.start_new_thread(
self._bigquery_uploader.upload_benchmark_run_json,
(self._bigquery_data_set,
self._bigquery_run_table,
self._run_id,
run_info))
def _gather_run_info(model_name, dataset_name, run_params):
"""Collect the benchmark run information for the local environment."""
run_info = {
......@@ -200,6 +264,25 @@ def _gather_run_info(model_name, dataset_name, run_params):
return run_info
def _process_metric_to_json(
name, value, unit=None, global_step=None, extras=None):
"""Validate the metric data and generate JSON for insert."""
if not isinstance(value, numbers.Number):
tf.logging.warning(
"Metric value to log should be a number. Got %s", type(value))
return None
extras = _convert_to_json_dict(extras)
return {
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"timestamp": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN),
"extras": extras}
def _collect_tensorflow_info(run_info):
run_info["tensorflow_version"] = {
"version": tf.VERSION, "git_hash": tf.GIT_VERSION}
......
......@@ -22,28 +22,55 @@ from __future__ import print_function
import json
import os
import tempfile
import time
import unittest
import mock
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
try:
from google.cloud import bigquery
except ImportError:
bigquery = None
from official.utils.flags import core as flags_core
from official.utils.logs import logger
class BenchmarkLoggerTest(tf.test.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BenchmarkLoggerTest, cls).setUpClass()
flags_core.define_benchmark()
def test_get_default_benchmark_logger(self):
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
with flagsaver.flagsaver(benchmark_logger_type='foo'):
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
def test_config_base_benchmark_logger(self):
logger.config_benchmark_logger("")
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
with flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger'):
logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
def test_config_benchmark_file_logger(self):
logger.config_benchmark_logger("/tmp/abc")
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkFileLogger)
# Set benchmark_log_dir first, since the benchmark_logger_type validator
# needs that value to be set when it runs.
with flagsaver.flagsaver(benchmark_log_dir='/tmp'):
with flagsaver.flagsaver(benchmark_logger_type='BenchmarkFileLogger'):
logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkFileLogger)
@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
def test_config_benchmark_bigquery_logger(self):
with flagsaver.flagsaver(benchmark_logger_type='BenchmarkBigQueryLogger'):
logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkBigQueryLogger)
class BaseBenchmarkLoggerTest(tf.test.TestCase):
......@@ -233,5 +260,46 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
self.assertIsNotNone(run_info["machine_config"]["memory_total"])
self.assertIsNotNone(run_info["machine_config"]["memory_available"])
@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
def setUp(self):
super(BenchmarkBigQueryLoggerTest, self).setUp()
# Avoid pulling extra env vars from the test environment, which can affect
# the test result; e.g. the Kokoro test has a TF_PKG env var which affects
# test_collect_tensorflow_environment_variables().
self.original_environ = dict(os.environ)
os.environ.clear()
self.mock_bq_uploader = mock.MagicMock()
self.logger = logger.BenchmarkBigQueryLogger(
self.mock_bq_uploader, "dataset", "run_table", "metric_table",
"run_id")
def tearDown(self):
super(BenchmarkBigQueryLoggerTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
os.environ.clear()
os.environ.update(self.original_environ)
def test_log_metric(self):
self.logger.log_metric(
"accuracy", 0.999, global_step=1e4, extras={"name": "value"})
expected_metric_json = [{
"name": "accuracy",
"value": 0.999,
"unit": None,
"global_step": 1e4,
"timestamp": mock.ANY,
"extras": [{"name": "name", "value": "value"}]
}]
# log_metric will call upload_benchmark_metric_json in a separate thread.
# Give the new thread a grace period before asserting.
time.sleep(1)
self.mock_bq_uploader.upload_benchmark_metric_json.assert_called_once_with(
"dataset", "metric_table", "run_id", expected_metric_json)
if __name__ == "__main__":
tf.test.main()