Unverified Commit d2d6ab4c authored by Qianli Scott Zhu, committed by GitHub

Add new test ID and test env info to the benchmark run. (#4426)

* Add new test ID and test env info to the benchmark run.

* Fix test.

* Fix lint

* Address review comment.
parent 47c5642e
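
In short: each model's benchmark logging call now forwards a caller-supplied test ID (taken from the new --benchmark_test_id / -bti flag), and the logger additionally records the detected test environment. A minimal sketch of the new call shape, with made-up model, dataset, and ID values:

from official.utils.logs import logger

benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(
    model_name="resnet",                  # illustrative
    dataset_name="ImageNet",              # illustrative
    run_params={"batch_size": 1024},      # illustrative
    # New in this commit: a human-readable, hardware-independent ID,
    # normally populated from the --benchmark_test_id flag.
    test_id="resnet50_bs1024_synthetic")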
@@ -210,7 +210,8 @@ def train_boosted_trees(flags_obj):
benchmark_logger.log_run_info(
model_name="boosted_trees",
dataset_name="higgs",
run_params=run_params)
run_params=run_params,
test_id=flags_obj.benchmark_test_id)
# Though BoostedTreesClassifier is under tf.estimator, faster in-memory
# training is as yet only provided as a contrib library.
@@ -244,6 +245,7 @@ def main(_):
def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration."""
flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_integer(
......
@@ -247,7 +247,8 @@ def run_ncf(_):
benchmark_logger.log_run_info(
model_name="recommendation",
dataset_name=FLAGS.dataset,
run_params=run_params)
run_params=run_params,
test_id=FLAGS.benchmark_test_id)
# Training and evaluation cycle
def train_input_fn():
......
@@ -395,8 +395,10 @@ def resnet_main(
'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs,
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info('resnet', dataset_name, run_params)
benchmark_logger.log_run_info('resnet', dataset_name, run_params,
test_id=flags_obj.benchmark_test_id)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
......
@@ -440,7 +440,8 @@ def run_transformer(flags_obj):
benchmark_logger.log_run_info(
model_name="transformer",
dataset_name="wmt_translate_ende",
run_params=params.__dict__)
run_params=params.__dict__,
test_id=flags_obj.benchmark_test_id)
# Train and evaluate transformer model
estimator = tf.estimator.Estimator(
......
@@ -43,6 +43,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work."))
flags.DEFINE_string(
name="benchmark_test_id", short_name="bti", default=None,
help=help_wrap("The unique test ID of the benchmark run. It could be the "
"combination of key parameters. It is hardware "
"independent and could be used compare the performance "
"between different test runs. This flag is designed for "
"human consumption, and does not have any impact within "
"the system."))
if benchmark_log_dir:
flags.DEFINE_string(
......
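
For context, a hedged sketch (not part of this commit) of how a script that calls flags_core.define_benchmark() would receive the new flag; the script name and ID value are hypothetical:

# Hypothetical invocation:
#   python my_model_main.py --benchmark_test_id=my_model_bs256_1gpu
from absl import app
from absl import flags
from official.utils.flags import core as flags_core

def main(_):
  # Real models forward this value to benchmark_logger.log_run_info(...,
  # test_id=flags.FLAGS.benchmark_test_id); here we just echo it.
  print("benchmark_test_id:", flags.FLAGS.benchmark_test_id)

if __name__ == "__main__":
  flags_core.define_benchmark()
  app.run(main)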
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that interact with cloud service.
"""
import requests
GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
def on_gcp():
"""Detect whether the current running environment is on GCP"""
try:
response = requests.get(GCP_METADATA_URL, headers=GCP_METADATA_HEADER)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
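
One practical caveat: on a machine that is not on GCP, the metadata request can hang until the connection attempt gives up. A hedged variant (not part of this commit) that bounds the wait using requests' timeout parameter:

import requests

GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}

def on_gcp_with_timeout(timeout_seconds=1.0):
  """Same check as on_gcp(), but never blocks longer than the timeout."""
  try:
    response = requests.get(
        GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=timeout_seconds)
    return response.status_code == 200
  except requests.exceptions.RequestException:
    return False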
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cloud_lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import mock
import requests
from official.utils.logs import cloud_lib
class CloudLibTest(unittest.TestCase):
@mock.patch("requests.get")
def test_on_gcp(self, mock_requests_get):
mock_response = mock.MagicMock()
mock_requests_get.return_value = mock_response
mock_response.status_code = 200
self.assertEqual(cloud_lib.on_gcp(), True)
@mock.patch("requests.get")
def test_not_on_gcp(self, mock_requests_get):
mock_requests_get.side_effect = requests.exceptions.ConnectionError()
self.assertEqual(cloud_lib.on_gcp(), False)
if __name__ == "__main__":
unittest.main()
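
A hedged extra case, not in this commit, that could be added inside CloudLibTest: a metadata endpoint that is reachable but returns a non-200 status should still be treated as not being on GCP.

@mock.patch("requests.get")
def test_on_gcp_non_200_status(self, mock_requests_get):
  # Hypothetical test method; it would live inside CloudLibTest above.
  mock_response = mock.MagicMock()
  mock_response.status_code = 404
  mock_requests_get.return_value = mock_response
  self.assertEqual(cloud_lib.on_gcp(), False)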
@@ -36,13 +36,17 @@ from absl import flags
import tensorflow as tf
from tensorflow.python.client import device_lib
from official.utils.logs import cloud_lib
METRIC_LOG_FILE_NAME = "metric.log"
BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
GCP_TEST_ENV = "GCP"
RUN_STATUS_SUCCESS = "success"
RUN_STATUS_FAILURE = "failure"
RUN_STATUS_RUNNING = "running"
FLAGS = flags.FLAGS
# Don't use it directly. Use get_benchmark_logger to access a logger.
@@ -141,9 +145,10 @@ class BaseBenchmarkLogger(object):
if metric:
tf.logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params):
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
tf.logging.info("Benchmark run: %s",
_gather_run_info(model_name, dataset_name, run_params))
_gather_run_info(model_name, dataset_name, run_params,
test_id))
def on_finish(self, status):
pass
@@ -183,7 +188,7 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
tf.logging.warning("Failed to dump metric to log file: "
"name %s, value %s, error %s", name, value, e)
def log_run_info(self, model_name, dataset_name, run_params):
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
@@ -193,8 +198,10 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run, built from a combination of
key parameters, e.g. batch size and number of GPUs. It is hardware independent.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params)
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
with tf.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
@@ -251,7 +258,7 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
self._run_id,
[metric]))
def log_run_info(self, model_name, dataset_name, run_params):
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
@@ -261,8 +268,10 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run, built from a combination of
key parameters, e.g. batch size and number of GPUs. It is hardware independent.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params)
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
# Starting a new thread for the BigQuery upload in case it takes a long time
# and impacts the benchmark and performance measurement. Starting a new
# thread might have a performance impact for models that run on CPU.
@@ -288,12 +297,13 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
status))
def _gather_run_info(model_name, dataset_name, run_params):
def _gather_run_info(model_name, dataset_name, run_params, test_id):
"""Collect the benchmark run information for the local environment."""
run_info = {
"model_name": model_name,
"dataset": {"name": dataset_name},
"machine_config": {},
"test_id": test_id,
"run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)}
_collect_tensorflow_info(run_info)
@@ -302,6 +312,7 @@ def _gather_run_info(model_name, dataset_name, run_params):
_collect_cpu_info(run_info)
_collect_gpu_info(run_info)
_collect_memory_info(run_info)
_collect_test_environment(run_info)
return run_info
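
For orientation, a hedged sketch of where the two new fields land in the resulting run_info dict; every value below is illustrative, not real output:

run_info = {
    "model_name": "resnet",
    "dataset": {"name": "ImageNet"},
    "machine_config": {},                     # filled by the _collect_*_info helpers
    "test_id": "resnet50_bs1024_synthetic",   # new: from --benchmark_test_id
    "run_date": "2018-06-15T01:23:45.678900Z",
    "test_environment": "GCP",                # new: set by _collect_test_environment
}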
@@ -403,6 +414,13 @@ def _collect_memory_info(run_info):
tf.logging.warn("'psutil' not imported. Memory info will not be logged.")
def _collect_test_environment(run_info):
"""Detect the local environment, eg GCE, AWS or DGX, etc."""
if cloud_lib.on_gcp():
run_info["test_environment"] = GCP_TEST_ENV
# TODO(scottzhu): Add more test environment detection for other platforms.
def _parse_gpu_model(physical_device_desc):
# Assume all the GPUs connected are the same model
for kv in physical_device_desc.split(","):
......
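
As a hedged illustration of the TODO in _collect_test_environment above, detection for another platform might probe that provider's instance metadata service. The endpoint, constant name, and helper below are assumptions for illustration, not part of this commit:

import requests

AWS_METADATA_URL = "http://169.254.169.254/latest/meta-data/instance-id"
AWS_TEST_ENV = "AWS"

def _on_aws(timeout_seconds=1.0):
  """Hypothetical helper: best-effort check for the EC2 metadata service."""
  try:
    response = requests.get(AWS_METADATA_URL, timeout=timeout_seconds)
    return response.status_code == 200
  except requests.exceptions.RequestException:
    return False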
@@ -246,7 +246,8 @@ def run_wide_deep(flags_obj):
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params,
test_id=flags_obj.benchmark_test_id)
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
train_hooks = hooks_helper.get_train_hooks(
......