"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "899cf5c4389b83fc7d913bf23ce07282d09ffb91"
Unverified Commit eb0c0dfd authored by Qianli Scott Zhu's avatar Qianli Scott Zhu Committed by GitHub
Browse files

Add dataset info and hyper parameter logging for benchmark. (#4152)

* Add dataset info and hyper parameter logging for benchmark.

* Address review comments.

* Address the review comment for data schema name.

* Fix test cases.

* Lint fix.
parent 8e73530e
......@@ -98,41 +98,41 @@
"type": "RECORD"
},
{
"description": "The list of hyperparameters of the model.",
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
"fields": [
{
"description": "The name of the hyperparameter.",
"description": "The name of the parameter.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The string value of the hyperparameter.",
"description": "The string value of the parameter.",
"mode": "NULLABLE",
"name": "string_value",
"type": "STRING"
},
{
"description": "The bool value of the hyperparameter.",
"description": "The bool value of the parameter.",
"mode": "NULLABLE",
"name": "bool_value",
"type": "STRING"
},
{
"description": "The int/long value of the hyperparameter.",
"description": "The int/long value of the parameter.",
"mode": "NULLABLE",
"name": "long_value",
"type": "INTEGER"
},
{
"description": "The double/float value of hyperparameter.",
"description": "The double/float value of parameter.",
"mode": "NULLABLE",
"name": "float_value",
"type": "FLOAT"
}
],
"mode": "REPEATED",
"name": "hyperparameter",
"name": "run_parameters",
"type": "RECORD"
},
{
......
......@@ -42,6 +42,8 @@ _NUM_IMAGES = {
'validation': 10000,
}
DATASET_NAME = 'CIFAR-10'
###############################################################################
# Data processing
......@@ -237,7 +239,7 @@ def run_cifar(flags_obj):
or input_fn)
resnet_run_loop.resnet_main(
flags_obj, cifar10_model_fn, input_function,
flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
......
......@@ -41,6 +41,7 @@ _NUM_IMAGES = {
_NUM_TRAIN_FILES = 1024
_SHUFFLE_BUFFER = 1500
DATASET_NAME = 'ImageNet'
###############################################################################
# Data processing
......@@ -312,7 +313,7 @@ def run_imagenet(flags_obj):
or input_fn)
resnet_run_loop.resnet_main(
flags_obj, imagenet_model_fn, input_function,
flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
......
......@@ -331,7 +331,8 @@ def per_device_batch_size(batch_size, num_gpus):
return int(batch_size / num_gpus)
def resnet_main(flags_obj, model_function, input_function, shape=None):
def resnet_main(
flags_obj, model_function, input_function, dataset_name, shape=None):
"""Shared main loop for ResNet Models.
Args:
......@@ -342,6 +343,8 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
input_function: the function that processes the dataset and returns a
dataset that the estimator can train on. This will be wrapped with
all the relevant flags for running and passed to estimator.
dataset_name: the name of the dataset for training and evaluation. This is
used for logging purpose.
shape: list of ints representing the shape of the images used for training.
This is only used if flags_obj.export_dir is passed.
"""
......@@ -381,8 +384,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
'dtype': flags_core.get_tf_dtype(flags_obj)
})
run_params = {
'batch_size': flags_obj.batch_size,
'dtype': flags_core.get_tf_dtype(flags_obj),
'resnet_size': flags_obj.resnet_size,
'resnet_version': flags_obj.version,
'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs,
}
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info('resnet')
benchmark_logger.log_run_info('resnet', dataset_name, run_params)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
......
......@@ -109,8 +109,9 @@ class BaseBenchmarkLogger(object):
"Name %s, value %d, unit %s, global_step %d, extras %s",
name, value, unit, global_step, extras)
def log_run_info(self, model_name):
tf.logging.info("Benchmark run: %s", _gather_run_info(model_name))
def log_run_info(self, model_name, dataset_name, run_params):
  """Log the collected benchmark run information via tf.logging.

  Args:
    model_name: string, the name of the model.
    dataset_name: string, the name of dataset for training and evaluation.
    run_params: dict, the dictionary of parameters for the run, it could
      include hyperparameters or other params that are important for the run.
  """
  run_info = _gather_run_info(model_name, dataset_name, run_params)
  tf.logging.info("Benchmark run: %s", run_info)
class BenchmarkFileLogger(BaseBenchmarkLogger):
......@@ -159,15 +160,18 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
tf.logging.warning("Failed to dump metric to log file: "
"name %s, value %s, error %s", name, value, e)
def log_run_info(self, model_name):
def log_run_info(self, model_name, dataset_name, run_params):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
"""
run_info = _gather_run_info(model_name)
run_info = _gather_run_info(model_name, dataset_name, run_params)
with tf.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
......@@ -179,15 +183,17 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
e)
def _gather_run_info(model_name):
def _gather_run_info(model_name, dataset_name, run_params):
"""Collect the benchmark run information for the local environment."""
run_info = {
"model_name": model_name,
"dataset": {"name": dataset_name},
"machine_config": {},
"run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)}
_collect_tensorflow_info(run_info)
_collect_tensorflow_environment_variables(run_info)
_collect_run_params(run_info, run_params)
_collect_cpu_info(run_info)
_collect_gpu_info(run_info)
_collect_memory_info(run_info)
......@@ -199,6 +205,21 @@ def _collect_tensorflow_info(run_info):
"version": tf.VERSION, "git_hash": tf.GIT_VERSION}
def _collect_run_params(run_info, run_params):
"""Log the parameter information for the benchmark run."""
def process_param(name, value):
type_check = {
str: {"name": name, "string_value": value},
int: {"name": name, "long_value": value},
bool: {"name": name, "bool_value": str(value)},
float: {"name": name, "float_value": value},
}
return type_check.get(type(value),
{"name": name, "string_value": str(value)})
if run_params:
run_info["run_parameters"] = [
process_param(k, v) for k, v in sorted(run_params.items())]
def _collect_tensorflow_environment_variables(run_info):
run_info["tensorflow_environment_variables"] = [
{"name": k, "value": v}
......
......@@ -180,6 +180,32 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION)
def test_collect_run_params(self):
  """_collect_run_params should type-tag each value and sort by name."""
  run_info = {}
  run_parameters = {
      "batch_size": 32,
      "synthetic_data": True,
      "train_epochs": 100.00,
      "dtype": "fp16",
      "resnet_size": 50,
      "random_tensor": tf.constant(2.0)
  }
  logger._collect_run_params(run_info, run_parameters)

  # All six parameters are recorded, ordered alphabetically by name.
  self.assertEqual(len(run_info["run_parameters"]), 6)
  # int values go to long_value.
  self.assertEqual(run_info["run_parameters"][0],
                   {"name": "batch_size", "long_value": 32})
  # str values go to string_value.
  self.assertEqual(run_info["run_parameters"][1],
                   {"name": "dtype", "string_value": "fp16"})
  # Unsupported types (a tf.Tensor here) fall back to str(value).
  self.assertEqual(run_info["run_parameters"][2],
                   {"name": "random_tensor", "string_value":
                    "Tensor(\"Const:0\", shape=(), dtype=float32)"})
  self.assertEqual(run_info["run_parameters"][3],
                   {"name": "resnet_size", "long_value": 50})
  # bool values are stored as their string form under bool_value.
  self.assertEqual(run_info["run_parameters"][4],
                   {"name": "synthetic_data", "bool_value": "True"})
  # float values go to float_value.
  self.assertEqual(run_info["run_parameters"][5],
                   {"name": "train_epochs", "float_value": 100.00})
def test_collect_tensorflow_environment_variables(self):
os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
os.environ["TF_OTHER"] = "2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment