Unverified commit 47c5642e, authored by Qianli Scott Zhu and committed by GitHub

Record the status for a benchmark run. (#4402)

* Update benchmark logger to update the run status.

This is important for streaming uploads to BigQuery, so that the
dashboard can ignore a 'running' benchmark that has not finished yet.

* Move the run status into a separate table.

Also update the run status in the benchmark uploader and
BenchmarkBigQueryLogger.

* Insert, rather than update, the benchmark status in the file logger.

* Address review comments.

Update the logger to provide a benchmark context manager, which updates
the run status accordingly (a usage sketch follows the commit metadata below).

* Fix broken tests.

* Move the benchmark logger context to main function.

* Fix tests.

* Update the rest of the models to use the context in main.

* Delint.
parent d530ac54
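
At a glance, the pattern this commit applies to each model's main() looks like the minimal sketch below. This is illustrative only, not code from the commit: run_model and the model/dataset names are stand-ins, and it assumes the benchmark flags from official.utils.flags have already been defined. The run is wrapped in logger.benchmark_context, which configures the logger, records the run as "running" once run info is logged (for the BigQuery logger), and flips the status to "success" or "failure" when the context exits.

# Illustrative sketch only -- mirrors the pattern added to cifar10_main.py,
# imagenet_main.py, ncf_main.py, transformer_main.py and wide_deep.py.
from absl import app as absl_app
from absl import flags

from official.utils.logs import logger

FLAGS = flags.FLAGS


def run_model(flags_obj):
  """Hypothetical training entry point standing in for run_cifar/run_ncf/etc."""
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info("my_model", "my_dataset", {})
  # ... build, train and evaluate the model here; any uncaught exception
  # causes benchmark_context to record the run as a failure and re-raise.


def main(_):
  # benchmark_context configures the logger from the flags and records
  # "success" or "failure" for the run when the block exits.
  with logger.benchmark_context(FLAGS):
    run_model(FLAGS)


if __name__ == "__main__":
  absl_app.run(main)
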
...@@ -27,6 +27,7 @@ from __future__ import print_function
import json

from google.cloud import bigquery
+from google.cloud import exceptions
import tensorflow as tf
...@@ -132,3 +133,25 @@ class BigQueryUploader(object):
    if errors:
      tf.logging.error(
          "Failed to upload benchmark info to bigquery: {}".format(errors))

+  def insert_run_status(self, dataset_name, table_name, run_id, run_status):
+    """Insert the run status into the BigQuery run status table."""
+    query = ("INSERT {ds}.{tb} "
+             "(run_id, status) "
+             "VALUES('{rid}', '{status}')").format(
+                 ds=dataset_name, tb=table_name, rid=run_id, status=run_status)
+    try:
+      self._bq_client.query(query=query).result()
+    except exceptions.GoogleCloudError as e:
+      tf.logging.error("Failed to insert run status: %s", e)
+
+  def update_run_status(self, dataset_name, table_name, run_id, run_status):
+    """Update the run status in the BigQuery run status table."""
+    query = ("UPDATE {ds}.{tb} "
+             "SET status = '{status}' "
+             "WHERE run_id = '{rid}'").format(
+                 ds=dataset_name, tb=table_name, status=run_status, rid=run_id)
+    try:
+      self._bq_client.query(query=query).result()
+    except exceptions.GoogleCloudError as e:
+      tf.logging.error("Failed to update run status: %s", e)
...@@ -54,6 +54,10 @@ def main(_):
  uploader.upload_metric_file(
      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
      metric_json_file)
+  # Assume the run finished successfully before the user invokes the upload script.
+  uploader.insert_run_status(
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_status_table,
+      run_id, logger.RUN_STATUS_SUCCESS)


if __name__ == "__main__":
...
...@@ -36,10 +36,10 @@ except ImportError:
  benchmark_uploader = None


-@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
+@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BigQueryUploaderTest(tf.test.TestCase):

-  @patch.object(bigquery, 'Client')
+  @patch.object(bigquery, "Client")
  def setUp(self, mock_bigquery):
    self.mock_client = mock_bigquery.return_value
    self.mock_dataset = MagicMock(name="dataset")
...@@ -52,56 +52,72 @@ class BigQueryUploaderTest(tf.test.TestCase):
    self.benchmark_uploader._bq_client = self.mock_client
    self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
-    with open(os.path.join(self.log_dir, 'metric.log'), 'a') as f:
-      json.dump({'name': 'accuracy', 'value': 1.0}, f)
+    with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
+      json.dump({"name": "accuracy", "value": 1.0}, f)
      f.write("\n")
-      json.dump({'name': 'loss', 'value': 0.5}, f)
+      json.dump({"name": "loss", "value": 0.5}, f)
      f.write("\n")
-    with open(os.path.join(self.log_dir, 'run.log'), 'w') as f:
-      json.dump({'model_name': 'value'}, f)
+    with open(os.path.join(self.log_dir, "run.log"), "w") as f:
+      json.dump({"model_name": "value"}, f)

  def tearDown(self):
    tf.gfile.DeleteRecursively(self.get_temp_dir())

  def test_upload_benchmark_run_json(self):
    self.benchmark_uploader.upload_benchmark_run_json(
-        'dataset', 'table', 'run_id', {'model_name': 'value'})
+        "dataset", "table", "run_id", {"model_name": "value"})
    self.mock_client.insert_rows_json.assert_called_once_with(
-        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
+        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_benchmark_metric_json(self):
    metric_json_list = [
-        {'name': 'accuracy', 'value': 1.0},
-        {'name': 'loss', 'value': 0.5}
+        {"name": "accuracy", "value": 1.0},
+        {"name": "loss", "value": 0.5}
    ]
    expected_params = [
-        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
-        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
+        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
+        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.benchmark_uploader.upload_benchmark_metric_json(
-        'dataset', 'table', 'run_id', metric_json_list)
+        "dataset", "table", "run_id", metric_json_list)
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

  def test_upload_benchmark_run_file(self):
    self.benchmark_uploader.upload_benchmark_run_file(
-        'dataset', 'table', 'run_id', os.path.join(self.log_dir, 'run.log'))
+        "dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))
    self.mock_client.insert_rows_json.assert_called_once_with(
-        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
+        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_metric_file(self):
    self.benchmark_uploader.upload_metric_file(
-        'dataset', 'table', 'run_id',
-        os.path.join(self.log_dir, 'metric.log'))
+        "dataset", "table", "run_id",
+        os.path.join(self.log_dir, "metric.log"))
    expected_params = [
-        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
-        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
+        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
+        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

+  def test_insert_run_status(self):
+    self.benchmark_uploader.insert_run_status(
+        "dataset", "table", "run_id", "status")
+    expected_query = ("INSERT dataset.table "
+                      "(run_id, status) "
+                      "VALUES('run_id', 'status')")
+    self.mock_client.query.assert_called_once_with(query=expected_query)
+
+  def test_update_run_status(self):
+    self.benchmark_uploader.update_run_status(
+        "dataset", "table", "run_id", "status")
+    expected_query = ("UPDATE dataset.table "
+                      "SET status = 'status' "
+                      "WHERE run_id = 'run_id'")
+    self.mock_client.query.assert_called_once_with(query=expected_query)
+

-if __name__ == '__main__':
+if __name__ == "__main__":
  tf.test.main()
...@@ -5,12 +5,6 @@
    "name": "model_id",
    "type": "STRING"
  },
-  {
-    "description": "The status of the run for the benchmark. Eg, running, failed, success",
-    "mode": "NULLABLE",
-    "name": "status",
-    "type": "STRING"
-  },
  {
    "description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
    "mode": "REQUIRED",
...
[
  {
    "description": "The UUID of the run for the benchmark.",
    "mode": "REQUIRED",
    "name": "run_id",
    "type": "STRING"
  },
  {
    "description": "The status of the run for the benchmark. Eg, running, failed, success",
    "mode": "REQUIRED",
    "name": "status",
    "type": "STRING"
  }
]
\ No newline at end of file
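
The run-status table described by the schema above has to exist before insert_run_status and update_run_status can write to it. Below is a hedged sketch of one way to create it with the same google-cloud-bigquery client the uploader already imports; the project and dataset names are placeholders, and the table name matches the default of the new --bigquery_run_status_table flag.

# Illustrative sketch: create the run-status table from the schema above.
# "my-gcp-project" and "benchmark_dataset" are placeholders.
from google.cloud import bigquery

client = bigquery.Client(project="my-gcp-project")
table_ref = client.dataset("benchmark_dataset").table("benchmark_run_status")
schema = [
    bigquery.SchemaField(
        "run_id", "STRING", mode="REQUIRED",
        description="The UUID of the run for the benchmark."),
    bigquery.SchemaField(
        "status", "STRING", mode="REQUIRED",
        description="The status of the run for the benchmark."),
]
client.create_table(bigquery.Table(table_ref, schema=schema))
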
...@@ -198,6 +198,12 @@ def per_device_batch_size(batch_size, num_gpus):
def main(_):
+  with logger.benchmark_context(FLAGS):
+    run_ncf(FLAGS)
+
+
+def run_ncf(_):
+  """Run NCF training and eval loop."""
  # Data preprocessing
  # The file name of training and test dataset
  train_fname = os.path.join(
...@@ -237,7 +243,7 @@ def main(_):
      "hr_threshold": FLAGS.hr_threshold,
      "train_epochs": FLAGS.train_epochs,
  }
-  benchmark_logger = logger.config_benchmark_logger(FLAGS)
+  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="recommendation",
      dataset_name=FLAGS.dataset,
...
...@@ -25,6 +25,7 @@ from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
+from official.utils.logs import logger
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
...@@ -236,14 +237,14 @@ def run_cifar(flags_obj):
  """
  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
                    or input_fn)
  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])


def main(_):
-  run_cifar(flags.FLAGS)
+  with logger.benchmark_context(flags.FLAGS):
+    run_cifar(flags.FLAGS)


if __name__ == '__main__':
...
...@@ -25,6 +25,7 @@ from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
+from official.utils.logs import logger
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
...@@ -321,7 +322,8 @@ def run_imagenet(flags_obj):
def main(_):
-  run_imagenet(flags.FLAGS)
+  with logger.benchmark_context(flags.FLAGS):
+    run_imagenet(flags.FLAGS)


if __name__ == '__main__':
...
...@@ -395,7 +395,7 @@ def resnet_main(
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
-  benchmark_logger = logger.config_benchmark_logger(flags_obj)
+  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
...@@ -415,7 +415,6 @@ def resnet_main(
      batch_size=per_device_batch_size(
          flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
      num_epochs=1)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
...
...@@ -436,7 +436,7 @@ def run_transformer(flags_obj):
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=params.batch_size  # for ExamplesPerSecondHook
  )
-  benchmark_logger = logger.config_benchmark_logger(flags_obj)
+  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
...@@ -445,6 +445,7 @@ def run_transformer(flags_obj):
  # Train and evaluate transformer model
  estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
+
  train_schedule(
      estimator=estimator,
      # Training arguments
...@@ -461,7 +462,8 @@ def run_transformer(flags_obj):
def main(_):
-  run_transformer(flags.FLAGS)
+  with logger.benchmark_context(flags.FLAGS):
+    run_transformer(flags.FLAGS)


if __name__ == "__main__":
...
...@@ -66,6 +66,12 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
      help=help_wrap("The Bigquery table name where the benchmark run "
                     "information will be uploaded."))

+  flags.DEFINE_string(
+      name="bigquery_run_status_table", short_name="brst",
+      default="benchmark_run_status",
+      help=help_wrap("The Bigquery table name where the benchmark run "
+                     "status information will be uploaded."))
+
  flags.DEFINE_string(
      name="bigquery_metric_table", short_name="bmt",
      default="benchmark_metric",
...
...@@ -22,6 +22,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

+import contextlib
import datetime
import json
import multiprocessing
...@@ -38,6 +39,9 @@ from tensorflow.python.client import device_lib
METRIC_LOG_FILE_NAME = "metric.log"
BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
+RUN_STATUS_SUCCESS = "success"
+RUN_STATUS_FAILURE = "failure"
+RUN_STATUS_RUNNING = "running"

FLAGS = flags.FLAGS
...@@ -66,6 +70,7 @@ def config_benchmark_logger(flag_obj=None):
        bigquery_uploader=bq_uploader,
        bigquery_data_set=flag_obj.bigquery_data_set,
        bigquery_run_table=flag_obj.bigquery_run_table,
+        bigquery_run_status_table=flag_obj.bigquery_run_status_table,
        bigquery_metric_table=flag_obj.bigquery_metric_table,
        run_id=str(uuid.uuid4()))
  else:
...@@ -83,6 +88,19 @@ def get_benchmark_logger():
  return _benchmark_logger

+@contextlib.contextmanager
+def benchmark_context(flag_obj):
+  """Context manager for a benchmark run, which updates the run status accordingly."""
+  benchmark_logger = config_benchmark_logger(flag_obj)
+  try:
+    yield
+    benchmark_logger.on_finish(RUN_STATUS_SUCCESS)
+  except Exception:  # pylint: disable=broad-except
+    # Catch all exceptions, update the run status to failure, and re-raise.
+    benchmark_logger.on_finish(RUN_STATUS_FAILURE)
+    raise

class BaseBenchmarkLogger(object):
  """Class to log the benchmark information to STDOUT."""
...@@ -127,6 +145,9 @@ class BaseBenchmarkLogger(object):
    tf.logging.info("Benchmark run: %s",
                    _gather_run_info(model_name, dataset_name, run_params))

+  def on_finish(self, status):
+    pass
+

class BenchmarkFileLogger(BaseBenchmarkLogger):
  """Class to log the benchmark information to local disk."""
...@@ -184,6 +205,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
      tf.logging.warning("Failed to dump benchmark run info to log file: %s",
                         e)

+  def on_finish(self, status):
+    pass
+

class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
  """Class to log the benchmark information to BigQuery data store."""
...@@ -192,12 +216,14 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
               bigquery_uploader,
               bigquery_data_set,
               bigquery_run_table,
+               bigquery_run_status_table,
               bigquery_metric_table,
               run_id):
    super(BenchmarkBigQueryLogger, self).__init__()
    self._bigquery_uploader = bigquery_uploader
    self._bigquery_data_set = bigquery_data_set
    self._bigquery_run_table = bigquery_run_table
+    self._bigquery_run_status_table = bigquery_run_status_table
    self._bigquery_metric_table = bigquery_metric_table
    self._run_id = run_id
...@@ -246,6 +272,20 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
           self._bigquery_run_table,
           self._run_id,
           run_info))
+      thread.start_new_thread(
+          self._bigquery_uploader.insert_run_status,
+          (self._bigquery_data_set,
+           self._bigquery_run_status_table,
+           self._run_id,
+           RUN_STATUS_RUNNING))
+
+  def on_finish(self, status):
+    thread.start_new_thread(
+        self._bigquery_uploader.update_run_status,
+        (self._bigquery_data_set,
+         self._bigquery_run_status_table,
+         self._run_id,
+         status))


def _gather_run_info(model_name, dataset_name, run_params):
...
...@@ -72,6 +72,23 @@ class BenchmarkLoggerTest(tf.test.TestCase):
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BenchmarkBigQueryLogger)

+  @mock.patch("official.utils.logs.logger.config_benchmark_logger")
+  def test_benchmark_context(self, mock_config_benchmark_logger):
+    mock_logger = mock.MagicMock()
+    mock_config_benchmark_logger.return_value = mock_logger
+    with logger.benchmark_context(None):
+      tf.logging.info("start benchmarking")
+    mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
+
+  @mock.patch("official.utils.logs.logger.config_benchmark_logger")
+  def test_benchmark_context_failure(self, mock_config_benchmark_logger):
+    mock_logger = mock.MagicMock()
+    mock_config_benchmark_logger.return_value = mock_logger
+    with self.assertRaises(RuntimeError):
+      with logger.benchmark_context(None):
+        raise RuntimeError("training error")
+    mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_FAILURE)
+

class BaseBenchmarkLoggerTest(tf.test.TestCase):
...@@ -200,6 +217,24 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
    metric_log = os.path.join(log_dir, "metric.log")
    self.assertFalse(tf.gfile.Exists(metric_log))

+  @mock.patch("official.utils.logs.logger._gather_run_info")
+  def test_log_run_info(self, mock_gather_run_info):
+    log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    log = logger.BenchmarkFileLogger(log_dir)
+    run_info = {"model_name": "model_name",
+                "dataset": "dataset_name",
+                "run_info": "run_value"}
+    mock_gather_run_info.return_value = run_info
+    log.log_run_info("model_name", "dataset_name", {})
+
+    run_log = os.path.join(log_dir, "benchmark_run.log")
+    self.assertTrue(tf.gfile.Exists(run_log))
+    with tf.gfile.GFile(run_log) as f:
+      run_info = json.loads(f.readline())
+      self.assertEqual(run_info["model_name"], "model_name")
+      self.assertEqual(run_info["dataset"], "dataset_name")
+      self.assertEqual(run_info["run_info"], "run_value")
+
  def test_collect_tensorflow_info(self):
    run_info = {}
    logger._collect_tensorflow_info(run_info)
...@@ -274,8 +309,8 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
    self.mock_bq_uploader = mock.MagicMock()
    self.logger = logger.BenchmarkBigQueryLogger(
-        self.mock_bq_uploader, "dataset", "run_table", "metric_table",
-        "run_id")
+        self.mock_bq_uploader, "dataset", "run_table", "run_status_table",
+        "metric_table", "run_id")

  def tearDown(self):
    super(BenchmarkBigQueryLoggerTest, self).tearDown()
...@@ -300,6 +335,29 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
    self.mock_bq_uploader.upload_benchmark_metric_json.assert_called_once_with(
        "dataset", "metric_table", "run_id", expected_metric_json)

+  @mock.patch("official.utils.logs.logger._gather_run_info")
+  def test_log_run_info(self, mock_gather_run_info):
+    run_info = {"model_name": "model_name",
+                "dataset": "dataset_name",
+                "run_info": "run_value"}
+    mock_gather_run_info.return_value = run_info
+    self.logger.log_run_info("model_name", "dataset_name", {})
+    # log_run_info will call upload_benchmark_run_json and insert_run_status
+    # in separate threads. Give the new threads some grace period before the
+    # assertions.
+    time.sleep(1)
+    self.mock_bq_uploader.upload_benchmark_run_json.assert_called_once_with(
+        "dataset", "run_table", "run_id", run_info)
+    self.mock_bq_uploader.insert_run_status.assert_called_once_with(
+        "dataset", "run_status_table", "run_id", "running")
+
+  def test_on_finish(self):
+    self.logger.on_finish(logger.RUN_STATUS_SUCCESS)
+    # on_finish will call update_run_status in a separate thread. Give the new
+    # thread some grace period before the assertion.
+    time.sleep(1)
+    self.mock_bq_uploader.update_run_status.assert_called_once_with(
+        "dataset", "run_status_table", "run_id", logger.RUN_STATUS_SUCCESS)
+

if __name__ == "__main__":
  tf.test.main()
...@@ -245,7 +245,7 @@ def run_wide_deep(flags_obj):
      'model_type': flags_obj.model_type,
  }

-  benchmark_logger = logger.config_benchmark_logger(flags_obj)
+  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)

  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
...@@ -280,7 +280,8 @@ def run_wide_deep(flags_obj):
def main(_):
-  run_wide_deep(flags.FLAGS)
+  with logger.benchmark_context(flags.FLAGS):
+    run_wide_deep(flags.FLAGS)


if __name__ == '__main__':
...