Commit 05631eec authored by liangjing

version 1

parent 7e0391d9
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that interact with cloud service.
"""
import requests
GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
def on_gcp():
"""Detect whether the current running environment is on GCP."""
try:
# Time out in 5 seconds, in case the test environment has connectivity issues.
# There is no default timeout, which means the request might block forever.
response = requests.get(
GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
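# A minimal, hedged usage sketch (not part of the original file): running this
# module directly prints the detection result; importing it has no side effects.
if __name__ == "__main__":
  # Prints True only when the GCP metadata server answers within the 5s timeout.
  print("Running on GCP:", on_gcp())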
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hook that counts examples per second every N steps or seconds."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from tf2_common.utils.logs import logger
class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and divided by the total number of steps to get the
average step time; batch_size is then used to compute the running average of
examples per second. The examples per second for the most recent interval are
also logged.
"""
def __init__(self,
batch_size,
every_n_steps=None,
every_n_secs=None,
warm_steps=0,
metric_logger=None):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of `every_n_steps`
and `every_n_secs` should be set.
warm_steps: The number of steps to skip before logging and running-average
calculation. warm_steps refers to global steps across all workers, not to
steps on each individual worker.
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises:
ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or if
both are set.
"""
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError("exactly one of every_n_steps"
" and every_n_secs should be provided.")
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.estimator.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0
self._total_steps = 0
self._batch_size = batch_size
self._warm_steps = warm_steps
# List of examples per second logged every_n_steps.
self.current_examples_per_sec_list = []
def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object that fetches the global step tensor.
"""
return tf.estimator.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step = run_values.results
if self._timer.should_trigger_for_step(
global_step) and global_step > self._warm_steps:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
self._step_train_time += elapsed_time
self._total_steps += elapsed_steps
# Average examples per second is based on the total (cumulative)
# training steps and training time so far.
average_examples_per_sec = self._batch_size * (
self._total_steps / self._step_train_time)
# Current examples per second is based on the training steps and
# training time elapsed since the last trigger.
current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time)
# Log entries to be read from the hook during or after the run.
self.current_examples_per_sec_list.append(current_examples_per_sec)
self._logger.log_metric(
"average_examples_per_sec", average_examples_per_sec,
global_step=global_step)
self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=global_step)
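# A hedged usage sketch (assumptions: `model_fn` and `input_fn` are supplied by
# the caller; they are not part of this module):
#
#   hook = ExamplesPerSecondHook(batch_size=128, every_n_steps=100,
#                                warm_steps=5)
#   estimator = tf.estimator.Estimator(model_fn=model_fn)
#   estimator.train(input_fn=input_fn, hooks=[hook])
#
# With metric_logger left as None, metrics go through BaseBenchmarkLogger and
# are written to the TensorFlow log.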
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks helper to return a list of TensorFlow hooks for training by name.
More hooks can be added to this set. To add a new hook: 1) add the new hook to
the registry in HOOKS, and 2) add a corresponding function that parses out the
necessary parameters.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from tf2_common.utils.logs import hooks
from tf2_common.utils.logs import logger
from tf2_common.utils.logs import metric_hook
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'cross_entropy',
'train_accuracy'])
def get_train_hooks(name_list, use_tpu=False, **kwargs):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings naming the desired hook classes. Allowed values
are the keys of HOOKS: LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook,
LoggingMetricHook, and StepCounterHook.
use_tpu: Boolean indicating whether computation occurs on a TPU. If True,
hooks are disabled altogether.
**kwargs: a dictionary of arguments to the hooks.
Returns:
list of instantiated hooks, ready to be used in a classifier.train call.
Raises:
ValueError: if an unrecognized name is passed.
"""
if not name_list:
return []
if use_tpu:
tf.compat.v1.logging.warning('hooks_helper received name_list `{}`, but a '
'TPU is specified. No hooks will be used.'
.format(name_list))
return []
train_hooks = []
for name in name_list:
hook_name = HOOKS.get(name.strip().lower())
if hook_name is None:
raise ValueError('Unrecognized training hook requested: {}'.format(name))
else:
train_hooks.append(hook_name(**kwargs))
return train_hooks
def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): # pylint: disable=unused-argument
"""Function to get LoggingTensorHook.
Args:
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
**kwargs: a dictionary of arguments to LoggingTensorHook.
Returns:
Returns a LoggingTensorHook with a standard set of tensors that will be
printed to stdout.
"""
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return tf.estimator.LoggingTensorHook(
tensors=tensors_to_log,
every_n_iter=every_n_iter)
def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=unused-argument
"""Function to get ProfilerHook.
Args:
model_dir: The directory to save the profile traces to.
save_steps: `int`, print profile traces every N steps.
**kwargs: a dictionary of arguments to ProfilerHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
def get_examples_per_second_hook(every_n_steps=100,
batch_size=128,
warm_steps=5,
**kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook.
Args:
every_n_steps: `int`, print current and average examples per second every
N steps.
batch_size: `int`, total batch size used to calculate examples/second from
global time.
warm_steps: skip this number of steps before logging and running-average
calculation.
**kwargs: a dictionary of arguments to ExamplesPerSecondHook.
Returns:
Returns an ExamplesPerSecondHook that logs the current and average number of
examples processed per second.
"""
return hooks.ExamplesPerSecondHook(
batch_size=batch_size, every_n_steps=every_n_steps,
warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())
def get_logging_metric_hook(tensors_to_log=None,
every_n_secs=600,
**kwargs): # pylint: disable=unused-argument
"""Function to get LoggingMetricHook.
Args:
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Defaults to every
10 minutes (600 seconds).
**kwargs: a dictionary of arguments.
Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format.
"""
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return metric_hook.LoggingMetricHook(
tensors=tensors_to_log,
metric_logger=logger.get_benchmark_logger(),
every_n_secs=every_n_secs)
def get_step_counter_hook(**kwargs):
"""Function to get StepCounterHook."""
del kwargs
return tf.estimator.StepCounterHook()
# A dictionary mapping each hook name to its corresponding factory function.
HOOKS = {
'loggingtensorhook': get_logging_tensor_hook,
'profilerhook': get_profiler_hook,
'examplespersecondhook': get_examples_per_second_hook,
'loggingmetrichook': get_logging_metric_hook,
'stepcounterhook': get_step_counter_hook
}
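# A hedged usage sketch of the registry above (assumptions: `estimator` and
# `input_fn` are created by the caller; `model_dir` is only needed when
# 'ProfilerHook' is requested):
#
#   train_hooks = get_train_hooks(
#       ['LoggingTensorHook', 'ExamplesPerSecondHook'], batch_size=128)
#   estimator.train(input_fn=input_fn, hooks=train_hooks)
#
# To register a new hook, follow the module docstring: add a
# get_<name>_hook(**kwargs) factory above and a lowercase key for it in HOOKS.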
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging utilities for benchmark.
For collecting local environment metrics like CPU and memory, certain python
packages need be installed. See README for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import datetime
import json
import multiprocessing
import numbers
import os
import threading
import uuid
from six.moves import _thread as thread
from absl import flags
import tensorflow as tf
from tensorflow.python.client import device_lib
from tf2_common.utils.logs import cloud_lib
METRIC_LOG_FILE_NAME = "metric.log"
BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
GCP_TEST_ENV = "GCP"
RUN_STATUS_SUCCESS = "success"
RUN_STATUS_FAILURE = "failure"
RUN_STATUS_RUNNING = "running"
FLAGS = flags.FLAGS
# Don't use it directly. Use get_benchmark_logger to access a logger.
_benchmark_logger = None
_logger_lock = threading.Lock()
def config_benchmark_logger(flag_obj=None):
"""Config the global benchmark logger."""
_logger_lock.acquire()
try:
global _benchmark_logger
if not flag_obj:
flag_obj = FLAGS
if (not hasattr(flag_obj, "benchmark_logger_type") or
flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"):
_benchmark_logger = BaseBenchmarkLogger()
elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger":
_benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger":
from tf2_common.benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top
bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
_benchmark_logger = BenchmarkBigQueryLogger(
bigquery_uploader=bq_uploader,
bigquery_data_set=flag_obj.bigquery_data_set,
bigquery_run_table=flag_obj.bigquery_run_table,
bigquery_run_status_table=flag_obj.bigquery_run_status_table,
bigquery_metric_table=flag_obj.bigquery_metric_table,
run_id=str(uuid.uuid4()))
else:
raise ValueError("Unrecognized benchmark_logger_type: %s"
% flag_obj.benchmark_logger_type)
finally:
_logger_lock.release()
return _benchmark_logger
def get_benchmark_logger():
if not _benchmark_logger:
config_benchmark_logger()
return _benchmark_logger
@contextlib.contextmanager
def benchmark_context(flag_obj):
"""Context of benchmark, which will update status of the run accordingly."""
benchmark_logger = config_benchmark_logger(flag_obj)
try:
yield
benchmark_logger.on_finish(RUN_STATUS_SUCCESS)
except Exception: # pylint: disable=broad-except
# Catch all exceptions, update the run status to failure, and re-raise.
benchmark_logger.on_finish(RUN_STATUS_FAILURE)
raise
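# A hedged usage sketch (assumption: absl flags, including
# benchmark_logger_type and, if needed, benchmark_log_dir, have been defined
# and parsed before main runs):
#
#   def main(_):
#     with benchmark_context(flags.FLAGS):
#       train_and_evaluate()  # run status is recorded as success or failure
#
#   absl.app.run(main)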
class BaseBenchmarkLogger(object):
"""Class to log the benchmark information to STDOUT."""
def log_evaluation_result(self, eval_results):
"""Log the evaluation result.
The evaluation result is a dictionary that contains the metrics defined in
model_fn. It also contains an entry for global_step holding the value of the
global step at which evaluation was performed.
Args:
eval_results: dict, the result of evaluate.
"""
if not isinstance(eval_results, dict):
tf.compat.v1.logging.warning(
"eval_results should be dictionary for logging. Got %s",
type(eval_results))
return
global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
for key in sorted(eval_results):
if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
self.log_metric(key, eval_results[key], global_step=global_step)
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to local file.
Currently the logging is done in a synchronized way. This should be updated
to log asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, e.g. "images per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
tf.compat.v1.logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
tf.compat.v1.logging.info(
"Benchmark run: %s", _gather_run_info(model_name, dataset_name,
run_params, test_id))
def on_finish(self, status):
pass
class BenchmarkFileLogger(BaseBenchmarkLogger):
"""Class to log the benchmark information to local disk."""
def __init__(self, logging_dir):
super(BenchmarkFileLogger, self).__init__()
self._logging_dir = logging_dir
if not tf.io.gfile.isdir(self._logging_dir):
tf.io.gfile.makedirs(self._logging_dir)
self._metric_file_handler = tf.io.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to local file.
Currently the logging is done synchronously; this should be updated to log
asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, e.g. "images per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
try:
json.dump(metric, self._metric_file_handler)
self._metric_file_handler.write("\n")
self._metric_file_handler.flush()
except (TypeError, ValueError) as e:
tf.compat.v1.logging.warning(
"Failed to dump metric to log file: name %s, value %s, error %s",
name, value, e)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run; it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run, derived from the
combination of key parameters, e.g. batch size and number of GPUs. It is
hardware independent.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
with tf.io.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
try:
json.dump(run_info, f)
f.write("\n")
except (TypeError, ValueError) as e:
tf.compat.v1.logging.warning(
"Failed to dump benchmark run info to log file: %s", e)
def on_finish(self, status):
self._metric_file_handler.flush()
self._metric_file_handler.close()
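# For illustration (hedged): each log_metric call above appends one JSON line
# to metric.log with the shape produced by _process_metric_to_json, roughly:
#
#   {"name": "current_examples_per_sec", "value": 1234.5, "unit": null,
#    "global_step": 100, "timestamp": "2018-01-01T00:00:00.000000Z",
#    "extras": []}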
class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
"""Class to log the benchmark information to BigQuery data store."""
def __init__(self,
bigquery_uploader,
bigquery_data_set,
bigquery_run_table,
bigquery_run_status_table,
bigquery_metric_table,
run_id):
super(BenchmarkBigQueryLogger, self).__init__()
self._bigquery_uploader = bigquery_uploader
self._bigquery_data_set = bigquery_data_set
self._bigquery_run_table = bigquery_run_table
self._bigquery_run_status_table = bigquery_run_status_table
self._bigquery_metric_table = bigquery_metric_table
self._run_id = run_id
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to bigquery.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, e.g. "images per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
# Start a new thread for the BigQuery upload since it might take a long
# time and impact the benchmark and performance measurement. Starting a new
# thread might still have a performance impact for models that run on the
# CPU.
thread.start_new_thread(
self._bigquery_uploader.upload_benchmark_metric_json,
(self._bigquery_data_set,
self._bigquery_metric_table,
self._run_id,
[metric]))
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run; it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run, derived from the
combination of key parameters, e.g. batch size and number of GPUs. It is
hardware independent.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
# Start a new thread for the BigQuery upload since it might take a long time
# and impact the benchmark and performance measurement. Starting a new thread
# might still have a performance impact for models that run on the CPU.
thread.start_new_thread(
self._bigquery_uploader.upload_benchmark_run_json,
(self._bigquery_data_set,
self._bigquery_run_table,
self._run_id,
run_info))
thread.start_new_thread(
self._bigquery_uploader.insert_run_status,
(self._bigquery_data_set,
self._bigquery_run_status_table,
self._run_id,
RUN_STATUS_RUNNING))
def on_finish(self, status):
self._bigquery_uploader.update_run_status(
self._bigquery_data_set,
self._bigquery_run_status_table,
self._run_id,
status)
def _gather_run_info(model_name, dataset_name, run_params, test_id):
"""Collect the benchmark run information for the local environment."""
run_info = {
"model_name": model_name,
"dataset": {"name": dataset_name},
"machine_config": {},
"test_id": test_id,
"run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)}
_collect_tensorflow_info(run_info)
_collect_tensorflow_environment_variables(run_info)
_collect_run_params(run_info, run_params)
_collect_cpu_info(run_info)
_collect_memory_info(run_info)
_collect_test_environment(run_info)
return run_info
def _process_metric_to_json(
name, value, unit=None, global_step=None, extras=None):
"""Validate the metric data and generate JSON for insert."""
if not isinstance(value, numbers.Number):
tf.compat.v1.logging.warning(
"Metric value to log should be a number. Got %s", type(value))
return None
extras = _convert_to_json_dict(extras)
return {
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"timestamp": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN),
"extras": extras}
def _collect_tensorflow_info(run_info):
run_info["tensorflow_version"] = {
"version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}
def _collect_run_params(run_info, run_params):
"""Log the parameter information for the benchmark run."""
def process_param(name, value):
type_check = {
str: {"name": name, "string_value": value},
int: {"name": name, "long_value": value},
bool: {"name": name, "bool_value": str(value)},
float: {"name": name, "float_value": value},
}
return type_check.get(type(value),
{"name": name, "string_value": str(value)})
if run_params:
run_info["run_parameters"] = [
process_param(k, v) for k, v in sorted(run_params.items())]
def _collect_tensorflow_environment_variables(run_info):
run_info["tensorflow_environment_variables"] = [
{"name": k, "value": v}
for k, v in sorted(os.environ.items()) if k.startswith("TF_")]
# The following code is mirrored from tensorflow/tools/test/system_info_lib
# which is not exposed for import.
def _collect_cpu_info(run_info):
"""Collect the CPU information for the local environment."""
cpu_info = {}
cpu_info["num_cores"] = multiprocessing.cpu_count()
try:
# Note: cpuinfo is not installed in the TensorFlow OSS tree.
# It is installable via pip.
import cpuinfo # pylint: disable=g-import-not-at-top
info = cpuinfo.get_cpu_info()
cpu_info["cpu_info"] = info["brand"]
cpu_info["mhz_per_cpu"] = info["hz_advertised_raw"][0] / 1.0e6
run_info["machine_config"]["cpu_info"] = cpu_info
except ImportError:
tf.compat.v1.logging.warn(
"'cpuinfo' not imported. CPU info will not be logged.")
def _collect_memory_info(run_info):
try:
# Note: psutil is not installed in the TensorFlow OSS tree.
# It is installable via pip.
import psutil # pylint: disable=g-import-not-at-top
vmem = psutil.virtual_memory()
run_info["machine_config"]["memory_total"] = vmem.total
run_info["machine_config"]["memory_available"] = vmem.available
except ImportError:
tf.compat.v1.logging.warn(
"'psutil' not imported. Memory info will not be logged.")
def _collect_test_environment(run_info):
"""Detect the local environment, eg GCE, AWS or DGX, etc."""
if cloud_lib.on_gcp():
run_info["test_environment"] = GCP_TEST_ENV
# TODO(scottzhu): Add more testing env detection for other platforms.
def _parse_gpu_model(physical_device_desc):
# Assume all connected GPUs are the same model.
for kv in physical_device_desc.split(","):
k, _, v = kv.partition(":")
if k.strip() == "name":
return v.strip()
return None
def _convert_to_json_dict(input_dict):
if input_dict:
return [{"name": k, "value": v} for k, v in sorted(input_dict.items())]
else:
return []
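# A minimal, hedged sketch (not part of the original module): exercising the
# default STDOUT logger directly, without configuring any absl flags.
if __name__ == "__main__":
  demo_logger = BaseBenchmarkLogger()
  demo_logger.log_metric("examples_per_sec", 1234.5, unit="images/sec",
                         global_step=100)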
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metric."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
class LoggingMetricHook(tf.estimator.LoggingTensorHook):
"""Hook to log benchmark metric information.
This hook is very similar to tf.estimator.LoggingTensorHook, which logs the
given tensors every N local steps, every N seconds, or at the end. The metric
information is logged via metric_logger in JSON format, which can be consumed
by a data analysis pipeline later.
Note that if `at_end` is True, `tensors` should not include any tensor
whose evaluation produces a side effect such as consuming additional inputs.
"""
def __init__(self, tensors, metric_logger=None,
every_n_iter=None, every_n_secs=None, at_end=False):
"""Initializer for LoggingMetricHook.
Args:
tensors: `dict` that maps string-valued tags to tensors/tensor names,
or `iterable` of tensors/tensor names.
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log.
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
every_n_secs: `int` or `float`, print the values of `tensors` once every N
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
at_end: `bool` specifying whether to print the values of `tensors` at the
end of the run.
Raises:
ValueError: if
1. `every_n_iter` is non-positive, or
2. neither or both of `every_n_iter` and `every_n_secs` are provided, or
3. `metric_logger` is not provided.
"""
super(LoggingMetricHook, self).__init__(
tensors=tensors,
every_n_iter=every_n_iter,
every_n_secs=every_n_secs,
at_end=at_end)
if metric_logger is None:
raise ValueError("metric_logger should be provided.")
self._logger = metric_logger
def begin(self):
super(LoggingMetricHook, self).begin()
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use LoggingMetricHook.")
if self._global_step_tensor.name not in self._current_tensors:
self._current_tensors[self._global_step_tensor.name] = (
self._global_step_tensor)
def after_run(self, unused_run_context, run_values):
# _should_trigger is an internal state populated in before_run; it uses
# self._timer to determine whether this run should trigger logging.
if self._should_trigger:
self._log_metric(run_values.results)
self._iter_count += 1
def end(self, session):
if self._log_at_end:
values = session.run(self._current_tensors)
self._log_metric(values)
def _log_metric(self, tensor_values):
self._timer.update_last_triggered_step(self._iter_count)
global_step = tensor_values[self._global_step_tensor.name]
# self._tag_order is populated during the init of LoggingTensorHook
for tag in self._tag_order:
self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
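# A hedged usage sketch (assumptions: the tensor name below exists in the graph
# built by the caller's model_fn, and `logger` refers to
# tf2_common.utils.logs.logger imported by the caller):
#
#   metric_hook = LoggingMetricHook(
#       tensors={'learning_rate': 'learning_rate'},
#       metric_logger=logger.get_benchmark_logger(),
#       every_n_secs=600)
#   estimator.train(input_fn=input_fn, hooks=[metric_hook])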