Commit 05631eec authored by liangjing's avatar liangjing
Browse files

version 1

parent 7e0391d9
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import random
import string
import tensorflow as tf
from tf2_common.utils.misc import tpu_lib
def _collective_communication(all_reduce_alg):
  """Map an all-reduce algorithm name to a CollectiveCommunication value.

  Args:
    all_reduce_alg: name of the collective communication implementation to
      use ('ring' or 'nccl'), or None to select AUTO.

  Returns:
    The matching tf.distribute.experimental.CollectiveCommunication member.

  Raises:
    ValueError: if `all_reduce_alg` is not one of None, 'ring' or 'nccl'.
  """
  communication = tf.distribute.experimental.CollectiveCommunication
  by_name = {
      None: communication.AUTO,
      "ring": communication.RING,
      "nccl": communication.NCCL,
  }
  if all_reduce_alg in by_name:
    return by_name[all_reduce_alg]
  message = (
      "When used with `multi_worker_mirrored`, valid values for "
      "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}")
  raise ValueError(message.format(all_reduce_alg))
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Build the CrossDeviceOps selected by `all_reduce_alg`.

  Args:
    all_reduce_alg: name of the cross device op to use ('nccl' or
      'hierarchical_copy'), or None.
    num_packs: integer forwarded as `num_packs` to the cross device op.

  Returns:
    An instance of the selected tf.distribute cross device op class, or None
    when `all_reduce_alg` is None.

  Raises:
    ValueError: if `all_reduce_alg` is not None, 'nccl' or
      'hierarchical_copy'.
  """
  # No explicit algorithm requested: the caller gets None back.
  if all_reduce_alg is None:
    return None

  ops_by_name = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce,
  }
  if all_reduce_alg in ops_by_name:
    return ops_by_name[all_reduce_alg](num_packs=num_packs)

  message = (
      "When used with `mirrored`, valid values for all_reduce_alg are "
      "['nccl', 'hierarchical_copy']. Supplied value: {}")
  raise ValueError(message.format(all_reduce_alg))
def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None,
                              tpu_zone=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
      insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
      to use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, no explicit choice is passed on.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents the TPU to connect to. Only
      used when `distribution_strategy` is 'tpu'; it is forwarded unchanged to
      `tpu_lib.tpu_initialize` and is not validated here.
    tpu_zone: Optional. String that represents the zone in which the TPU
      resides. Only used when `distribution_strategy` is 'tpu'.

  Returns:
    A tf.distribute strategy object, or None when `distribution_strategy` is
    'off'.

  Raises:
    ValueError: if `num_gpus` is negative; if `distribution_strategy` is 'off'
      or 'one_device' and `num_gpus` is larger than 1; or if
      `distribution_strategy` is not one of the accepted values.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  distribution_strategy = distribution_strategy.lower()

  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(
          "When {} GPUs are specified, distribution_strategy "
          "flag cannot be set to 'off'.".format(num_gpus))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address, tpu_zone)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    # Report the choice through the logger instead of a bare print so it
    # follows the configured verbosity and does not pollute stdout.
    tf.compat.v1.logging.info("Using MultiWorkerMirroredStrategy.")
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.compat.v1.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      "Unrecognized Distribution Strategy: %r" % distribution_strategy)
def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. For using with Estimator, we need to get per GPU batch.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device. `batch_size` is returned unchanged when `num_gpus`
    is 0 or 1.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  # divmod keeps the arithmetic exact for integers; true division would go
  # through a float and lose precision for very large batch sizes.
  per_replica, remainder = divmod(batch_size, num_gpus)
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(per_replica)
# `SyntheticDataset` is a stop-gap for producing synthetic data directly on
# devices. It is only useful for Keras with Distribution Strategies; better
# support is expected in `tf.data` or Distribution Strategy later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    """Capture one element of `dataset` into local variables.

    Args:
      dataset: dataset whose first element is used as the synthetic data.
      split_by: number of equal pieces each component is split into along
        axis 0; only the first piece is kept.
    """
    # dataset.take(1) doesn't have GPU kernel, so fetch the element on CPU.
    with tf.device('device:CPU:0'):
      sample = tf.data.experimental.get_single_element(dataset.take(1))

    local_variables = []
    for component in tf.nest.flatten(sample):
      shard = tf.split(component, num_or_size_splits=split_by, axis=0)[0]
      assert shard.shape.is_fully_defined(), shard.shape
      local_variables.append(
          tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=shard))
    initializers = [variable.initializer for variable in local_variables]

    # Restore the nesting of the sampled element around the variables.
    structured = tf.nest.pack_sequence_as(sample, local_variables)
    self._iterator = SyntheticIterator(structured, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    """Return a random string of `size` characters drawn from `chars`."""
    picked = [random.choice(chars) for _ in range(size)]
    return ''.join(picked)

  def __iter__(self):
    """Return the single iterator owned by this dataset."""
    return self._iterator

  def make_one_shot_iterator(self):
    """Return the single iterator owned by this dataset."""
    return self._iterator

  def make_initializable_iterator(self):
    """Return the single iterator owned by this dataset."""
    return self._iterator
class SyntheticIterator(object):
  """An iterator that returns the same synthetic input on every step."""

  def __init__(self, input_data, initializers):
    """Store the data to return and the initializers to expose.

    Args:
      input_data: the value returned by every `get_next()` / `next()` call.
      initializers: value returned by `initialize()` outside eager mode.
    """
    self._input_data = input_data
    self._initializers = initializers

  def __iter__(self):
    """Return self so the object satisfies the iterator protocol."""
    return self

  def get_next(self):
    """Return the stored input data."""
    return self._input_data

  def next(self):
    """Python 2 style alias for `__next__`."""
    return self.__next__()

  def __next__(self):
    """Return the next element, translating OutOfRangeError to StopIteration."""
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    """Return a no-op when executing eagerly, else the stored initializers."""
    if tf.executing_eagerly():
      return tf.no_op()
    return self._initializers
def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s dataset methods to return synthetic data.

  Replaces `make_dataset_iterator` and `experimental_distribute_dataset` on
  `strategy`, keeping the previous implementations in
  `orig_make_dataset_iterator` and `orig_distribute_dataset`.

  Args:
    strategy: the object (the callers in this file pass strategy classes)
      whose methods are replaced.
  """

  def make_dataset(self, dataset):
    tf.compat.v1.logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  # Only remember the originals the first time this object is patched.
  # Otherwise a second call would save the already-patched functions and the
  # real implementations could never be restored.
  own_attributes = vars(strategy)
  if 'orig_make_dataset_iterator' not in own_attributes:
    strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  if 'orig_distribute_dataset' not in own_attributes:
    strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset

  strategy.make_dataset_iterator = make_iterator
  strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
  """Restore the dataset methods saved on `strategy` by the monkey patch.

  Each method is restored only if its saved original is present, so calling
  this on an object that was never patched is a no-op.

  Args:
    strategy: the object whose `make_dataset_iterator` and
      `experimental_distribute_dataset` attributes are restored.
  """
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    # Restore onto the attribute the original was taken from.
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data():
  """Patch the supported strategy classes to produce synthetic data."""
  strategy_classes = (
      tf.distribute.OneDeviceStrategy,
      tf.distribute.MirroredStrategy,
      tf.distribute.experimental.MultiWorkerMirroredStrategy,
  )
  for strategy_class in strategy_classes:
    _monkey_patch_dataset_method(strategy_class)
def undo_set_up_synthetic_data():
  """Undo the synthetic-data patch on the supported strategy classes."""
  strategy_classes = (
      tf.distribute.OneDeviceStrategy,
      tf.distribute.MirroredStrategy,
      tf.distribute.experimental.MultiWorkerMirroredStrategy,
  )
  for strategy_class in strategy_classes:
    _undo_monkey_patch_dataset_method(strategy_class)
def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  An existing, non-empty TF_CONFIG takes precedence: it is only read and
  `worker_hosts` is ignored. Otherwise, when `worker_hosts` is given, a
  TF_CONFIG describing those workers is written to the environment.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of this worker in `worker_hosts`. Forced to 0 when
      there is exactly one worker.

  Returns:
    Number of workers in the cluster.

  Raises:
    ValueError: if more than one worker host is given and `task_index` is
      negative.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  print("tf_config",tf_config)

  if tf_config:
    cluster = tf_config['cluster']
    chief_count = len(cluster.get('chief', []))
    worker_count = len(cluster.get('worker', []))
    return chief_count + worker_count

  if not worker_hosts:
    return 1

  workers = worker_hosts.split(',')
  num_workers = len(workers)
  if num_workers > 1 and task_index < 0:
    raise ValueError('Must specify task_index when number of workers > 1')
  if num_workers == 1:
    task_index = 0

  cluster_spec = {
      'cluster': {'worker': workers},
      'task': {'type': 'worker', 'index': task_index},
  }
  os.environ['TF_CONFIG'] = json.dumps(cluster_spec)
  return num_workers
def get_strategy_scope(strategy):
  """Return `strategy.scope()`, or a no-op context manager if falsy.

  Args:
    strategy: an object exposing `scope()`, or a falsy value such as None.

  Returns:
    The result of `strategy.scope()` when `strategy` is truthy, otherwise a
    new `DummyContextManager`.
  """
  return strategy.scope() if strategy else DummyContextManager()
class DummyContextManager(object):
  """Context manager that does nothing on entry or exit."""

  def __enter__(self):
    """Enter the context; nothing is set up and None is bound by `as`."""
    return None

  def __exit__(self, *args):
    """Leave the context; exceptions are not suppressed."""
    return None
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for the Keras implementations of models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import time
from absl import logging
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.eager import profiler
class BatchTimestamp(object):
  """Pairs a batch index with the time stamp recorded for it."""

  def __init__(self, batch_index, timestamp):
    """Store `batch_index` and `timestamp` as attributes of the same name."""
    self.batch_index, self.timestamp = batch_index, timestamp

  def __repr__(self):
    """Return a quoted, human-readable description of the record."""
    template = "'BatchTimestamp<batch_index: {}, timestamp: {}>'"
    return template.format(self.batch_index, self.timestamp)
class TimeHistory(tf.keras.callbacks.Callback):
  """Keras callback that records step timing and throughput."""

  def __init__(self, batch_size, log_steps, logdir=None):
    """Callback for logging performance.

    Args:
      batch_size: Total batch size.
      log_steps: Interval of steps between logging of batch level stats.
      logdir: Optional directory to write TensorBoard summaries.
    """
    # TODO(wcromar): remove this parameter and rely on `logs` parameter of
    # on_train_batch_end()
    self.batch_size = batch_size
    super(TimeHistory, self).__init__()

    self.log_steps = log_steps
    self.last_log_step = 0
    self.steps_before_epoch = 0
    self.steps_in_epoch = 0
    self.start_time = None

    self.summary_writer = (
        tf.summary.create_file_writer(logdir) if logdir else None)

    # Holds the start of step 1, then the end of each step at which stats are
    # logged (every `log_steps` steps).
    self.timestamp_log = []

    # Wall-clock duration of each completed epoch.
    self.epoch_runtime_log = []

  @property
  def global_steps(self):
    """The current 1-indexed global step."""
    return self.steps_before_epoch + self.steps_in_epoch

  @property
  def average_steps_per_second(self):
    """The average training steps per second across all epochs."""
    total_runtime = sum(self.epoch_runtime_log)
    return self.global_steps / total_runtime

  @property
  def average_examples_per_second(self):
    """The average number of training examples per second across all epochs."""
    return self.average_steps_per_second * self.batch_size

  def on_train_end(self, logs=None):
    """Record the finish time and flush pending summaries, if any."""
    self.train_finish_time = time.time()
    if self.summary_writer:
      self.summary_writer.flush()

  def on_epoch_begin(self, epoch, logs=None):
    """Record the wall-clock time at which the epoch starts."""
    self.epoch_start = time.time()

  def on_batch_begin(self, batch, logs=None):
    """Start the interval timer and log the very first step's start time."""
    if not self.start_time:
      self.start_time = time.time()

    # Record the timestamp of the first global step
    if not self.timestamp_log:
      first_stamp = BatchTimestamp(self.global_steps, self.start_time)
      self.timestamp_log.append(first_stamp)

  def on_batch_end(self, batch, logs=None):
    """Records elapse time of the batch and calculates examples per second."""
    self.steps_in_epoch = batch + 1
    steps_since_last_log = self.global_steps - self.last_log_step
    if steps_since_last_log < self.log_steps:
      return

    now = time.time()
    elapsed_time = now - self.start_time
    steps_per_second = steps_since_last_log / elapsed_time
    examples_per_second = steps_per_second * self.batch_size

    self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
    logging.info(
        "TimeHistory: %.2f examples/second between steps %d and %d",
        examples_per_second, self.last_log_step, self.global_steps)

    if self.summary_writer:
      with self.summary_writer.as_default():
        tf.summary.scalar('global_step/sec', steps_per_second,
                          self.global_steps)
        tf.summary.scalar('examples/sec', examples_per_second,
                          self.global_steps)

    # Begin a new measurement interval; the timer restarts on the next
    # on_batch_begin call.
    self.last_log_step = self.global_steps
    self.start_time = None

  def on_epoch_end(self, epoch, logs=None):
    """Record the epoch runtime and fold its steps into the running total."""
    self.epoch_runtime_log.append(time.time() - self.epoch_start)
    self.steps_before_epoch += self.steps_in_epoch
    self.steps_in_epoch = 0
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
                          steps_per_epoch):
  """Validate profile_steps flag value and return profiler callback.

  Args:
    model_dir: directory passed to the callback as its log directory.
    profile_steps: string of the form "<start>,<stop>" naming the first and
      last steps to profile.
    enable_tensorboard: when truthy, a warning about possible overlap with
      the TensorBoard callback is logged.
    steps_per_epoch: number of steps per epoch, forwarded to the callback.

  Returns:
    A `ProfilerCallback` for the requested step range.

  Raises:
    ValueError: if `profile_steps` is not two comma separated integers with
      0 <= start <= stop.
  """
  usage = (
      'profile_steps must be a comma separated pair of positive integers, '
      'specifying the first and last steps to be profiled.'
  )

  try:
    step_bounds = [int(field) for field in profile_steps.split(',')]
  except ValueError:
    raise ValueError(usage)
  if len(step_bounds) != 2:
    raise ValueError(usage)

  first_step, last_step = step_bounds
  if not 0 <= first_step <= last_step:
    raise ValueError(usage)

  if enable_tensorboard:
    logging.warning(
        'Both TensorBoard and profiler callbacks are used. Note that the '
        'TensorBoard callback profiles the 2nd step (unless otherwise '
        'specified). Please make sure the steps profiled by the two callbacks '
        'do not overlap.')

  return ProfilerCallback(model_dir, first_step, last_step, steps_per_epoch)
class ProfilerCallback(tf.keras.callbacks.Callback):
  """Save profiles in specified step range to log directory."""

  def __init__(self, log_dir, start_step, stop_step, steps_per_epoch):
    """Store the step range and split it into epoch / in-epoch positions.

    Args:
      log_dir: directory the profile is saved to.
      start_step: global step at which profiling starts.
      stop_step: global step at which profiling stops.
      steps_per_epoch: number of steps in one epoch.
    """
    super(ProfilerCallback, self).__init__()
    self.log_dir = log_dir
    self.start_step = start_step
    self.stop_step = stop_step

    # Express each global step as (epoch, step within that epoch).
    self.start_epoch, self.start_step_in_epoch = divmod(
        start_step, steps_per_epoch)
    self.stop_epoch, self.stop_step_in_epoch = divmod(
        stop_step, steps_per_epoch)

    self.should_start = False
    self.should_stop = False

  def on_epoch_begin(self, epoch, logs=None):
    """Arm the start/stop triggers when their epoch begins."""
    if self.start_epoch == epoch:
      self.should_start = True
    if self.stop_epoch == epoch:
      self.should_stop = True

  def on_batch_begin(self, batch, logs=None):
    """Start the profiler at the armed start step."""
    if not (self.should_start and batch == self.start_step_in_epoch):
      return
    self.should_start = False
    profiler.start()
    logging.info('Profiler started at Step %s', self.start_step)

  def on_batch_end(self, batch, logs=None):
    """Stop the profiler at the armed stop step and save its results."""
    if not (self.should_stop and batch == self.stop_step_in_epoch):
      return
    self.should_stop = False
    results = profiler.stop()
    profiler.save(self.log_dir, results)
    logging.info(
        'Profiler saved profiles for steps between %s and %s to %s',
        self.start_step, self.stop_step, self.log_dir)
def set_session_config(enable_eager=False,
                       enable_xla=False):
  """Sets the session config.

  When `is_v2_0()` is true only the XLA setting is applied through
  `set_config_v2`, and `enable_eager` is not consulted. Otherwise a v1 config
  proto is built and used either to enable eager execution or to install a
  new Keras backend session.

  Args:
    enable_eager: in the non-v2 path, enable eager execution instead of
      creating a session.
    enable_xla: whether to turn on XLA JIT compilation.
  """
  if is_v2_0():
    set_config_v2(enable_xla=enable_xla)
    return

  config = get_config_proto_v1(enable_xla=enable_xla)
  if enable_eager:
    tf.compat.v1.enable_eager_execution(config=config)
  else:
    # Use the compat.v1 endpoints: `tf.Session` and
    # `tf.keras.backend.set_session` are not part of the TF 2.x top-level API,
    # but this branch can still be reached there when v2 behavior is disabled.
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
def get_config_proto_v1(enable_xla=False):
  """Return config proto according to flag settings, or None to use default.

  Args:
    enable_xla: when truthy, return a ConfigProto whose global JIT level is
      set to ON_2.

  Returns:
    A `tf.compat.v1.ConfigProto`, or None when `enable_xla` is falsy.
  """
  config = None
  if enable_xla:
    config = tf.compat.v1.ConfigProto()
    # `tf.OptimizerOptions` is missing from the TF 2.x top-level API; the
    # compat.v1 endpoint matches the ConfigProto used above.
    config.graph_options.optimizer_options.global_jit_level = (
        tf.compat.v1.OptimizerOptions.ON_2)
  return config
def set_config_v2(enable_xla=False):
  """Config eager context according to flag values using TF 2.0 API.

  Args:
    enable_xla: when truthy, turn on JIT compilation through
      `tf.config.optimizer.set_jit`; otherwise nothing is changed.
  """
  if not enable_xla:
    return
  tf.config.optimizer.set_jit(True)
def is_v2_0():
  """Returns true if using tf 2.0, as reported by `tf2.enabled()`."""
  v2_enabled = tf2.enabled()
  return v2_enabled
def set_gpu_thread_mode_and_count(gpu_thread_mode, num_gpus,
                                  per_gpu_thread_count):
  """Set GPU thread mode and count, and recommend dataset threads count.

  Args:
    gpu_thread_mode: value written to the TF_GPU_THREAD_MODE environment
      variable.
    num_gpus: number of GPUs in use.
    per_gpu_thread_count: threads per GPU; a falsy value selects 2.

  Returns:
    The recommended `datasets_num_private_threads` value.
  """
  cpu_count = multiprocessing.cpu_count()
  logging.info('Logical CPU cores: %s', cpu_count)

  # Allocate private thread pool for each GPU to schedule and launch kernels
  if not per_gpu_thread_count:
    per_gpu_thread_count = 2
  os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
  os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
  for variable in ('TF_GPU_THREAD_COUNT', 'TF_GPU_THREAD_MODE'):
    logging.info(variable + ': %s', os.environ[variable])

  # Limit data preprocessing threadpool to CPU cores minus number of total GPU
  # private threads and memory copy threads.
  gpu_private_threads = per_gpu_thread_count * num_gpus
  runtime_threads = num_gpus
  spare_cpu_threads = cpu_count - gpu_private_threads - runtime_threads
  datasets_num_private_threads = min(spare_cpu_threads, num_gpus * 8)
  logging.info('Recommended datasets_num_private_threads: %s',
               datasets_num_private_threads)
  return datasets_num_private_threads
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous functions that can be called by models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import tensorflow as tf
from tensorflow.python.util import nest
def past_stop_threshold(stop_threshold, eval_metric):
  """Return a boolean representing whether a model should be stopped.

  Args:
    stop_threshold: float, the threshold at or above which a model should stop
      training. None disables the check.
    eval_metric: float, the current value of the relevant metric to check.

  Returns:
    True if training should stop, False otherwise.

  Raises:
    ValueError: if either stop_threshold or eval_metric is not a number
  """
  if stop_threshold is None:
    return False

  checks = (
      (stop_threshold,
       "Threshold for checking stop conditions must be a number."),
      (eval_metric,
       "Eval metric being checked against stop conditions "
       "must be a number."),
  )
  for value, complaint in checks:
    if not isinstance(value, numbers.Number):
      raise ValueError(complaint)

  if eval_metric < stop_threshold:
    return False

  if eval_metric >= stop_threshold:
    tf.compat.v1.logging.info(
        "Stop threshold of {} was passed with metric value {}.".format(
            stop_threshold, eval_metric))
    return True
  return False
def generate_synthetic_data(
    input_shape, input_value=0, input_dtype=None, label_shape=None,
    label_value=0, label_dtype=None):
  """Create a repeating dataset with constant values.

  Args:
    input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of
      the input data.
    input_value: Value of each input element.
    input_dtype: Input dtype. If None, will be inferred by the input value.
    label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of
      the label data. When falsy, the dataset holds inputs only.
    label_value: Value of each label element.
    label_dtype: Label dtype. If None, will be inferred by the label value.

  Returns:
    Dataset of tensors or tuples of tensors (if label_shape is set).
  """
  # TODO(kathywu): Replace with SyntheticDataset once it is in contrib.

  def constant_filler(value, dtype):
    # Returns a function building a constant tensor of a given shape.
    return lambda shape: tf.constant(value, dtype, shape)

  element = nest.map_structure(
      constant_filler(input_value, input_dtype), input_shape)
  if label_shape:
    labels = nest.map_structure(
        constant_filler(label_value, label_dtype), label_shape)
    element = (element, labels)
  return tf.data.Dataset.from_tensors(element).repeat()
def apply_clean(flags_obj):
  """Remove the model directory when the clean flag is set.

  Args:
    flags_obj: object exposing `clean` and `model_dir` attributes. The
      directory is removed only if `clean` is truthy and the path exists.
  """
  if not flags_obj.clean:
    return
  if not tf.io.gfile.exists(flags_obj.model_dir):
    return
  tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:"
                            " {}".format(flags_obj.model_dir))
  tf.io.gfile.rmtree(flags_obj.model_dir)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Initializes TPU system for TF 2.0."""
import tensorflow as tf
def tpu_initialize(tpu_address, tpu_zone=None):
  """Initializes TPU for TF 2.0 training.

  Args:
    tpu_address: string, bns address of master TPU worker.
    tpu_zone: optional string. zone in which the tpu resides in.

  Returns:
    A TPUClusterResolver.
  """
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address, zone=tpu_zone)

  # The cluster connection is skipped for the '' and 'local' addresses.
  is_local = tpu_address in ('', 'local')
  if not is_local:
    tf.config.experimental_connect_to_cluster(resolver)

  tf.tpu.experimental.initialize_tpu_system(resolver)
  return resolver
# Copyright 2018 MLBenchmark Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience function for logging compliance tags to stdout.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import json
import logging
import os
import re
import sys
import time
PATTERN = re.compile('[a-zA-Z0-9]+')

# Optional path of a file that also receives the compliance log.
LOG_FILE = os.environ.get('COMPLIANCE_FILE')

# Logger dedicated to MLPerf compliance records.
LOGGER = logging.getLogger('mlperf_compliance')
LOGGER.setLevel(logging.DEBUG)

# stdout always gets a handler; it is limited to INFO and above only when a
# log file is configured to hold the DEBUG records.
_STREAM_HANDLER = logging.StreamHandler(stream=sys.stdout)
_STREAM_HANDLER.setLevel(logging.INFO if LOG_FILE else logging.DEBUG)
LOGGER.addHandler(_STREAM_HANDLER)

if LOG_FILE:
  _FILE_HANDLER = logging.FileHandler(LOG_FILE)
  _FILE_HANDLER.setLevel(logging.DEBUG)
  LOGGER.addHandler(_FILE_HANDLER)
def get_caller(stack_index=2, root_dir=None):
  """Return the (filename, line number) of a frame on the call stack.

  Args:
    stack_index: index into `inspect.stack()`; 0 is this function's own
      frame, 1 its caller, and so on.
    root_dir: optional directory prefix. When the filename starts with
      `root_dir + '/'`, that prefix is removed for readability.

  Returns:
    A `(filename, lineno)` tuple for the selected frame.
  """
  caller = inspect.getframeinfo(inspect.stack()[stack_index][0])

  # Trim the filenames for readability.
  filename = caller.filename
  if root_dir is not None:
    # Compare as plain text: building a regex from the path would
    # misinterpret characters such as '.', '+' or '(' in directory names.
    prefix = root_dir + '/'
    if filename.startswith(prefix):
      filename = filename[len(prefix):]
  return (filename, caller.lineno)
# Example of a formatted line (the timestamp is printed with three decimals
# and the value and metadata are JSON-encoded):
# :::MLL 1556733699.710 run_start: {"value": null,
#     "metadata": {"lineno": 77, "file": "main.py"}}
LOG_TEMPLATE = ':::MLL {:.3f} {}: {{"value": {}, "metadata": {}}}'


def mlperf_format(key, value, stack_offset=0, metadata=None):
  """Format a message for MLPerf.

  Args:
    key: tag name placed after the timestamp.
    value: JSON-serializable value of the tag.
    stack_offset: number of extra stack frames to skip when looking up the
      caller's file and line number.
    metadata: optional dict of extra JSON-serializable fields. If it has no
      'lineno' entry, 'lineno' and 'file' are filled in from the call stack.
      The dict passed in is never modified.

  Returns:
    The formatted log line as a string.
  """
  # Work on a copy so the caller's dict is not mutated; otherwise a dict
  # reused across calls would keep the first call's 'lineno' and 'file'.
  metadata = {} if metadata is None else dict(metadata)
  if 'lineno' not in metadata:
    filename, lineno = get_caller(2 + stack_offset, root_dir=None)
    metadata['lineno'] = lineno
    metadata['file'] = filename

  now = time.time()
  return LOG_TEMPLATE.format(now, key, json.dumps(value), json.dumps(metadata))
def mlperf_print(key, value, stack_offset=0, metadata=None):
  """Format an MLPerf message and emit it at INFO level on `LOGGER`.

  Args:
    key: tag name of the message.
    value: value of the tag.
    stack_offset: number of extra stack frames to skip when attributing the
      message to a caller.
    metadata: optional dict of extra fields for the message.
  """
  # One more frame is skipped so the reported location is this function's
  # caller rather than this function.
  message = mlperf_format(
      key, value, stack_offset=stack_offset + 1, metadata=metadata)
  LOGGER.info(message)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment