Commit 441c8f40 authored by qianyj

update TF code

parent ec90ad8e
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test TPU optimized matmul embedding."""
import numpy as np
import tensorflow as tf
from official.utils.accelerator import tpu as tpu_utils
TEST_CASES = [
dict(embedding_dim=256, vocab_size=1000, sequence_length=64,
batch_size=32, seed=54131),
dict(embedding_dim=8, vocab_size=15, sequence_length=12,
batch_size=256, seed=536413),
dict(embedding_dim=2048, vocab_size=512, sequence_length=50,
batch_size=8, seed=35124)
]
class TPUBaseTester(tf.test.TestCase):
def construct_embedding_and_values(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
np.random.seed(seed)
embeddings = np.random.random(size=(vocab_size, embedding_dim))
embedding_table = tf.convert_to_tensor(embeddings, dtype=tf.float32)
tokens = np.random.randint(low=1, high=vocab_size-1,
size=(batch_size, sequence_length))
for i in range(batch_size):
tokens[i, np.random.randint(low=0, high=sequence_length-1):] = 0
values = tf.convert_to_tensor(tokens, dtype=tf.int32)
mask = tf.to_float(tf.not_equal(values, 0))
return embedding_table, values, mask
def _test_embedding(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
"""Test that matmul embedding matches embedding lookup (gather)."""
with self.test_session():
embedding_table, values, mask = self.construct_embedding_and_values(
embedding_dim=embedding_dim,
vocab_size=vocab_size,
sequence_length=sequence_length,
batch_size=batch_size,
seed=seed
)
embedding = (tf.nn.embedding_lookup(params=embedding_table, ids=values) *
tf.expand_dims(mask, -1))
matmul_embedding = tpu_utils.embedding_matmul(
embedding_table=embedding_table, values=values, mask=mask)
self.assertAllClose(embedding, matmul_embedding)
def _test_masking(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
"""Test that matmul embedding properly zeros masked positions."""
with self.test_session():
embedding_table, values, mask = self.construct_embedding_and_values(
embedding_dim=embedding_dim,
vocab_size=vocab_size,
sequence_length=sequence_length,
batch_size=batch_size,
seed=seed
)
matmul_embedding = tpu_utils.embedding_matmul(
embedding_table=embedding_table, values=values, mask=mask)
self.assertAllClose(matmul_embedding,
matmul_embedding * tf.expand_dims(mask, -1))
def test_embedding_0(self):
self._test_embedding(**TEST_CASES[0])
def test_embedding_1(self):
self._test_embedding(**TEST_CASES[1])
def test_embedding_2(self):
self._test_embedding(**TEST_CASES[2])
def test_masking_0(self):
self._test_masking(**TEST_CASES[0])
def test_masking_1(self):
self._test_masking(**TEST_CASES[1])
def test_masking_2(self):
self._test_masking(**TEST_CASES[2])
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for managing dataset file buffers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import atexit
import multiprocessing
import os
import tempfile
import uuid
import numpy as np
import six
import tensorflow as tf
class _GarbageCollector(object):
"""Deletes temporary buffer files at exit.
Certain tasks (such as NCF Recommendation) require writing buffers to
temporary files (which may be local or distributed). It is not generally safe
to delete these files during operation, but they should be cleaned up. This
class keeps track of temporary files created, and deletes them at exit.
"""
def __init__(self):
self.temp_buffers = []
def register(self, filepath):
self.temp_buffers.append(filepath)
def purge(self):
try:
for i in self.temp_buffers:
if tf.gfile.Exists(i):
tf.gfile.Remove(i)
tf.logging.info("Buffer file {} removed".format(i))
except Exception as e:
tf.logging.error("Failed to cleanup buffer files: {}".format(e))
_GARBAGE_COLLECTOR = _GarbageCollector()
atexit.register(_GARBAGE_COLLECTOR.purge)
_ROWS_PER_CORE = 50000
def write_to_temp_buffer(dataframe, buffer_folder, columns):
if buffer_folder is None:
_, buffer_path = tempfile.mkstemp()
else:
tf.gfile.MakeDirs(buffer_folder)
buffer_path = os.path.join(buffer_folder, str(uuid.uuid4()))
_GARBAGE_COLLECTOR.register(buffer_path)
return write_to_buffer(dataframe, buffer_path, columns)
def iter_shard_dataframe(df, rows_per_core=1000):
"""Two way shard of a dataframe.
This function evenly shards a dataframe so that it can be mapped efficiently.
It yields a list of dataframes with length equal to the number of CPU cores,
with each dataframe having rows_per_core rows. (Except for the last batch
which may have fewer rows in the dataframes.) Passing vectorized inputs to
a multiprocessing pool is much more efficient than iterating through a
dataframe serially and passing a list of inputs to the pool.
Args:
df: Pandas dataframe to be sharded.
rows_per_core: Number of rows in each shard.
Yields:
A list of dataframe shards.
"""
n = len(df)
num_cores = min([multiprocessing.cpu_count(), n])
num_blocks = int(np.ceil(n / num_cores / rows_per_core))
max_batch_size = num_cores * rows_per_core
for i in range(num_blocks):
min_index = i * max_batch_size
max_index = min([(i + 1) * max_batch_size, n])
df_shard = df[min_index:max_index]
n_shard = len(df_shard)
boundaries = np.linspace(0, n_shard, num_cores + 1, dtype=np.int64)
yield [df_shard[boundaries[j]:boundaries[j+1]] for j in range(num_cores)]
def _shard_dict_to_examples(shard_dict):
"""Converts a dict of arrays into a list of example bytes."""
n = [i for i in shard_dict.values()][0].shape[0]
feature_list = [{} for _ in range(n)]
for column, values in shard_dict.items():
if len(values.shape) == 1:
values = np.reshape(values, values.shape + (1,))
if values.dtype.kind == "i":
feature_map = lambda x: tf.train.Feature(
int64_list=tf.train.Int64List(value=x))
elif values.dtype.kind == "f":
feature_map = lambda x: tf.train.Feature(
float_list=tf.train.FloatList(value=x))
else:
raise ValueError("Invalid dtype")
for i in range(n):
feature_list[i][column] = feature_map(values[i])
examples = [
tf.train.Example(features=tf.train.Features(feature=example_features))
for example_features in feature_list
]
return [e.SerializeToString() for e in examples]
def _serialize_shards(df_shards, columns, pool, writer):
"""Map sharded dataframes to bytes, and write them to a buffer.
Args:
df_shards: A list of pandas dataframes. (Should be of similar size)
columns: The dataframe columns to be serialized.
pool: A multiprocessing pool to serialize in parallel.
writer: A TFRecordWriter to write the serialized shards.
"""
# Pandas does not store columns of arrays as nd arrays. stack remedies this.
map_inputs = [{c: np.stack(shard[c].values, axis=0) for c in columns}
for shard in df_shards]
# Failure within pools is very irksome. Thus, it is better to thoroughly check
# inputs in the main process.
for inp in map_inputs:
# Check that all fields have the same number of rows.
assert len(set([v.shape[0] for v in inp.values()])) == 1
for val in inp.values():
assert hasattr(val, "dtype")
assert hasattr(val.dtype, "kind")
assert val.dtype.kind in ("i", "f")
assert len(val.shape) in (1, 2)
shard_bytes = pool.map(_shard_dict_to_examples, map_inputs)
for s in shard_bytes:
for example in s:
writer.write(example)
def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
"""Write a dataframe to a binary file for a dataset to consume.
Args:
dataframe: The pandas dataframe to be serialized.
buffer_path: The path where the serialized results will be written.
columns: The dataframe columns to be serialized.
expected_size: The size in bytes of the serialized results. This is used to
lazily construct the buffer.
Returns:
The path of the buffer.
"""
if tf.gfile.Exists(buffer_path) and tf.gfile.Stat(buffer_path).length > 0:
actual_size = tf.gfile.Stat(buffer_path).length
if expected_size == actual_size:
return buffer_path
tf.logging.warning(
"Existing buffer {} has size {}. Expected size {}. Deleting and "
"rebuilding buffer.".format(buffer_path, actual_size, expected_size))
tf.gfile.Remove(buffer_path)
if dataframe is None:
raise ValueError(
"dataframe was None but a valid existing buffer was not found.")
tf.gfile.MakeDirs(os.path.split(buffer_path)[0])
tf.logging.info("Constructing TFRecordDataset buffer: {}".format(buffer_path))
count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count())
try:
with tf.python_io.TFRecordWriter(buffer_path) as writer:
for df_shards in iter_shard_dataframe(df=dataframe,
rows_per_core=_ROWS_PER_CORE):
_serialize_shards(df_shards, columns, pool, writer)
count += sum([len(s) for s in df_shards])
tf.logging.info("{}/{} examples written."
.format(str(count).ljust(8), len(dataframe)))
finally:
pool.terminate()
tf.logging.info("Buffer write complete.")
return buffer_path
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for binary data file utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import multiprocessing
# pylint: disable=wrong-import-order
import numpy as np
import pandas as pd
import tensorflow as tf
# pylint: enable=wrong-import-order
from official.utils.data import file_io
_RAW_ROW = "raw_row"
_DUMMY_COL = "column_0"
_DUMMY_VEC_COL = "column_1"
_DUMMY_VEC_LEN = 4
_ROWS_PER_CORE = 4
_TEST_CASES = [
# One batch of one
dict(row_count=1, cpu_count=1, expected=[
[[0]]
]),
dict(row_count=10, cpu_count=1, expected=[
[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]]
]),
dict(row_count=21, cpu_count=1, expected=[
[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]],
[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20]]
]),
dict(row_count=1, cpu_count=4, expected=[
[[0]]
]),
dict(row_count=10, cpu_count=4, expected=[
[[0, 1], [2, 3, 4], [5, 6], [7, 8, 9]]
]),
dict(row_count=21, cpu_count=4, expected=[
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
[[16], [17], [18], [19, 20]]
]),
dict(row_count=10, cpu_count=8, expected=[
[[0], [1], [2], [3, 4], [5], [6], [7], [8, 9]]
]),
dict(row_count=40, cpu_count=8, expected=[
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15],
[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27],
[28, 29, 30, 31]],
[[32], [33], [34], [35], [36], [37], [38], [39]]
]),
]
_FEATURE_MAP = {
_RAW_ROW: tf.FixedLenFeature([1], dtype=tf.int64),
_DUMMY_COL: tf.FixedLenFeature([1], dtype=tf.int64),
_DUMMY_VEC_COL: tf.FixedLenFeature([_DUMMY_VEC_LEN], dtype=tf.float32)
}
@contextlib.contextmanager
def fixed_core_count(cpu_count):
"""Override CPU count.
file_io.py uses the cpu_count function to scale to the size of the instance.
However, this is not desirable for testing because it can make the test flaky.
Instead, this context manager fixes the count for more robust testing.
Args:
cpu_count: How many cores multiprocessing claims to have.
Yields:
Nothing. (for context manager only)
"""
old_count_fn = multiprocessing.cpu_count
multiprocessing.cpu_count = lambda: cpu_count
yield
multiprocessing.cpu_count = old_count_fn
class BaseTest(tf.test.TestCase):
def _test_sharding(self, row_count, cpu_count, expected):
df = pd.DataFrame({_DUMMY_COL: list(range(row_count))})
with fixed_core_count(cpu_count):
shards = list(file_io.iter_shard_dataframe(df, _ROWS_PER_CORE))
result = [[j[_DUMMY_COL].tolist() for j in i] for i in shards]
self.assertAllEqual(expected, result)
def test_tiny_rows_low_core(self):
self._test_sharding(**_TEST_CASES[0])
def test_small_rows_low_core(self):
self._test_sharding(**_TEST_CASES[1])
def test_large_rows_low_core(self):
self._test_sharding(**_TEST_CASES[2])
def test_tiny_rows_medium_core(self):
self._test_sharding(**_TEST_CASES[3])
def test_small_rows_medium_core(self):
self._test_sharding(**_TEST_CASES[4])
def test_large_rows_medium_core(self):
self._test_sharding(**_TEST_CASES[5])
def test_small_rows_large_core(self):
self._test_sharding(**_TEST_CASES[6])
def test_large_rows_large_core(self):
self._test_sharding(**_TEST_CASES[7])
def _serialize_deserialize(self, num_cores=1, num_rows=20):
np.random.seed(1)
df = pd.DataFrame({
# Serialization order is only deterministic for num_cores=1. raw_row is
# used in validation after the deserialization.
_RAW_ROW: np.array(range(num_rows), dtype=np.int64),
_DUMMY_COL: np.random.randint(0, 35, size=(num_rows,)),
_DUMMY_VEC_COL: [
np.array([np.random.random() for _ in range(_DUMMY_VEC_LEN)])
for i in range(num_rows) # pylint: disable=unused-variable
]
})
with fixed_core_count(num_cores):
buffer_path = file_io.write_to_temp_buffer(
df, self.get_temp_dir(), [_RAW_ROW, _DUMMY_COL, _DUMMY_VEC_COL])
with self.test_session(graph=tf.Graph()) as sess:
dataset = tf.data.TFRecordDataset(buffer_path)
dataset = dataset.batch(1).map(
lambda x: tf.parse_example(x, _FEATURE_MAP))
data_iter = dataset.make_one_shot_iterator()
seen_rows = set()
for i in range(num_rows+5):
row = data_iter.get_next()
try:
row_id, val_0, val_1 = sess.run(
[row[_RAW_ROW], row[_DUMMY_COL], row[_DUMMY_VEC_COL]])
row_id, val_0, val_1 = row_id[0][0], val_0[0][0], val_1[0]
assert row_id not in seen_rows
seen_rows.add(row_id)
self.assertEqual(val_0, df[_DUMMY_COL][row_id])
self.assertAllClose(val_1, df[_DUMMY_VEC_COL][row_id])
self.assertLess(i, num_rows, msg="Too many rows.")
except tf.errors.OutOfRangeError:
self.assertGreaterEqual(i, num_rows, msg="Too few rows.")
file_io._GARBAGE_COLLECTOR.purge()
assert not tf.gfile.Exists(buffer_path)
def test_serialize_deserialize_0(self):
self._serialize_deserialize(num_cores=1)
def test_serialize_deserialize_1(self):
self._serialize_deserialize(num_cores=2)
def test_serialize_deserialize_2(self):
self._serialize_deserialize(num_cores=8)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporting utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.export import export
class ExportUtilsTest(tf.test.TestCase):
"""Tests for the ExportUtils."""
def test_build_tensor_serving_input_receiver_fn(self):
receiver_fn = export.build_tensor_serving_input_receiver_fn(shape=[4, 5])
with tf.Graph().as_default():
receiver = receiver_fn()
self.assertIsInstance(
receiver, tf.estimator.export.TensorServingInputReceiver)
self.assertIsInstance(receiver.features, tf.Tensor)
self.assertEqual(receiver.features.shape, tf.TensorShape([1, 4, 5]))
self.assertEqual(receiver.features.dtype, tf.float32)
self.assertIsInstance(receiver.receiver_tensors, dict)
# Note that Python 3 can no longer index .values() directly; cast to list.
self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
tf.TensorShape([1, 4, 5]))
def test_build_tensor_serving_input_receiver_fn_batch_dtype(self):
receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape=[4, 5], dtype=tf.int8, batch_size=10)
with tf.Graph().as_default():
receiver = receiver_fn()
self.assertIsInstance(
receiver, tf.estimator.export.TensorServingInputReceiver)
self.assertIsInstance(receiver.features, tf.Tensor)
self.assertEqual(receiver.features.shape, tf.TensorShape([10, 4, 5]))
self.assertEqual(receiver.features.dtype, tf.int8)
self.assertIsInstance(receiver.receiver_tensors, dict)
# Note that Python 3 can no longer index .values() directly; cast to list.
self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
tf.TensorShape([10, 4, 5]))
if __name__ == "__main__":
tf.test.main()
# Adding Abseil (absl) flags quickstart
## Defining a flag
absl flag definitions are similar to argparse, although they are defined in a global namespace.
For instance, defining a string flag looks like:
```python
from absl import flags
flags.DEFINE_string(
name="my_flag",
default="a_sensible_default",
help="Here is what this flag does."
)
```
All three arguments are required, but default may be `None`. A common optional argument is
`short_name`, for defining abbreviations. Certain `DEFINE_*` methods have other required arguments.
For instance, `DEFINE_enum` requires the `enum_values` argument to be specified.
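As a brief illustration, a hypothetical enum flag with an abbreviated name might be defined as follows (the flag name and values here are made up for this example):
```python
from absl import flags

# Hypothetical enum flag; "color_mode" and its values are illustrative only.
flags.DEFINE_enum(
    name="color_mode",
    short_name="cm",  # allows the abbreviation --cm
    default="rgb",
    enum_values=["rgb", "grayscale"],
    help="Which color mode to use when decoding images."
)
```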
## Key Flags
absl has the concept of a key flag. Any flag defined in `__main__` is considered a key flag by
default. Key flags are displayed in `--help`, others only appear in `--helpfull`. In order to
handle key flags that are defined outside the module in question, absl provides the
`flags.adopt_module_key_flags()` method. This adds the key flags of a different module to one's own
key flags. For example:
```python
File: flag_source.py
---------------------------------------
from absl import flags
flags.DEFINE_string(name="my_flag", default="abc", help="a flag.")
```
```python
File: my_module.py
---------------------------------------
from absl import app as absl_app
from absl import flags
import flag_source
flags.adopt_module_key_flags(flag_source)
def main(_):
pass
absl_app.run(main, [__file__, "-h"])
```
When `my_module.py` is run it will show the help text for `my_flag`. Because not all flags defined
in a file are equally important, `official/utils/flags/core.py` (generally imported as flags_core)
provides an abstraction for handling key flag declaration in an easy way through the
`register_key_flags_in_core()` function, which allows a module to make a single
`flags.adopt_module_key_flags(flags_core)` call when using the util flag declaration functions.
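A minimal sketch of how a model entry point might use this, assuming the flags_core helpers from this repository (the file name below is hypothetical):
```python
File: my_model_main.py
---------------------------------------
from absl import app as absl_app
from absl import flags

from official.utils.flags import core as flags_core

def define_flags():
  # The define_* helpers in flags_core are wrapped by
  # register_key_flags_in_core, so their flags are declared key in core.
  flags_core.define_base()
  # A single call then adopts those key flags for this module's --help.
  flags.adopt_module_key_flags(flags_core)

def main(_):
  print(flags.FLAGS.batch_size)

if __name__ == "__main__":
  define_flags()
  absl_app.run(main)
```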
## Validators
Often the constraints on a flag are complicated. absl provides the validator decorator to allow
one to mark a function as a flag validation function. Suppose we want users to provide a flag
which is a palindrome.
```python
from absl import flags
flags.DEFINE_string(name="pal_flag", short_name="pf", default="", help="Give me a palindrome")
@flags.validator("pal_flag")
def _check_pal(provided_pal_flag):
return provided_pal_flag == provided_pal_flag[::-1]
```
A validator passes when it returns True (or any truthy value); returning False or None, or raising
an exception, causes validation to fail.
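Constraints that span several flags can be expressed with `flags.multi_flags_validator`, which receives a dict mapping the listed flag names to their values. A minimal sketch (the flag names here are illustrative):
```python
from absl import flags

flags.DEFINE_integer(name="lower", default=0, help="Lower bound.")
flags.DEFINE_integer(name="upper", default=10, help="Upper bound.")

@flags.multi_flags_validator(
    ["lower", "upper"],
    message="--lower must not exceed --upper")
def _check_bounds(flags_dict):
  # flags_dict maps each listed flag name to its current value.
  return flags_dict["lower"] <= flags_dict["upper"]
```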
## Testing
To test using absl, simply declare flags in the `setUpClass` method of the test case.
```python
import unittest

from absl import flags

from official.utils.flags import core as flags_core

def define_flags():
  flags.DEFINE_string(name="test_flag", default="abc", help="an example flag")

class BaseTester(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    super(BaseTester, cls).setUpClass()
    define_flags()

  def test_trivial(self):
    flags_core.parse_flags([__file__, "--test_flag", "def"])
    self.assertEqual(flags.FLAGS.test_flag, "def")
```
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True, batch_size=True,
num_gpu=True, hooks=True, export_dir=True):
"""Register base flags.
Args:
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
clean: Create a flag for removing the model_dir if it exists.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if data_dir:
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp",
help=help_wrap("The location of the input data."))
key_flags.append("data_dir")
if model_dir:
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp",
help=help_wrap("The location of the model checkpoint files."))
key_flags.append("model_dir")
if clean:
flags.DEFINE_boolean(
name="clean", default=False,
help=help_wrap("If set, model_dir will be removed if it exists."))
key_flags.append("clean")
if train_epochs:
flags.DEFINE_integer(
name="train_epochs", short_name="te", default=1,
help=help_wrap("The number of epochs used to train."))
key_flags.append("train_epochs")
if epochs_between_evals:
flags.DEFINE_integer(
name="epochs_between_evals", short_name="ebe", default=1,
help=help_wrap("The number of training epochs to run between "
"evaluations."))
key_flags.append("epochs_between_evals")
if stop_threshold:
flags.DEFINE_float(
name="stop_threshold", short_name="st",
default=None,
help=help_wrap("If passed, training will stop at the earlier of "
"train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold."))
if batch_size:
flags.DEFINE_integer(
name="batch_size", short_name="bs", default=32,
help=help_wrap("Batch size for training and evaluation. When using "
"multiple gpus, this is the global batch size for "
"all devices. For example, if the batch size is 32 "
"and there are 4 GPUs, each GPU will get 8 examples on "
"each step."))
key_flags.append("batch_size")
if num_gpu:
flags.DEFINE_integer(
name="num_gpus", short_name="ng",
default=1 if tf.test.is_gpu_available() else 0,
help=help_wrap(
"How many GPUs to use with the DistributionStrategies API. The "
"default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))
if hooks:
# Construct a pretty summary of hooks.
hook_list_str = (
u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
in hooks_helper.HOOKS]))
flags.DEFINE_list(
name="hooks", short_name="hk", default="LoggingTensorHook",
help=help_wrap(
u"A list of (case insensitive) strings to specify the names of "
u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
u"for details.".format(hook_list_str))
)
key_flags.append("hooks")
if export_dir:
flags.DEFINE_string(
name="export_dir", short_name="ed", default=None,
help=help_wrap("If set, a SavedModel serialization of the model will "
"be exported to this directory at the end of training. "
"See the README for more details and relevant links.")
)
key_flags.append("export_dir")
return key_flags
def get_num_gpus(flags_obj):
"""Treat num_gpus=-1 as 'use all'."""
if flags_obj.num_gpus != -1:
return flags_obj.num_gpus
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
local_device_protos = device_lib.list_local_devices()
return sum([1 for d in local_device_protos if d.device_type == "GPU"])
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for benchmarking models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from official.utils.flags._conventions import help_wrap
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
"""Register benchmarking flags.
Args:
benchmark_log_dir: Create a flag to specify location for benchmark logging.
bigquery_uploader: Create flags for uploading results to BigQuery.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
flags.DEFINE_enum(
name="benchmark_logger_type", default="BaseBenchmarkLogger",
enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger",
"BenchmarkBigQueryLogger"],
help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work."))
flags.DEFINE_string(
name="benchmark_test_id", short_name="bti", default=None,
help=help_wrap("The unique test ID of the benchmark run. It could be the "
"combination of key parameters. It is hardware "
"independent and could be used compare the performance "
"between different test runs. This flag is designed for "
"human consumption, and does not have any impact within "
"the system."))
if benchmark_log_dir:
flags.DEFINE_string(
name="benchmark_log_dir", short_name="bld", default=None,
help=help_wrap("The location of the benchmark logging.")
)
if bigquery_uploader:
flags.DEFINE_string(
name="gcp_project", short_name="gp", default=None,
help=help_wrap(
"The GCP project name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_data_set", short_name="bds", default="test_benchmark",
help=help_wrap(
"The Bigquery dataset name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_run_table", short_name="brt", default="benchmark_run",
help=help_wrap("The Bigquery table name where the benchmark run "
"information will be uploaded."))
flags.DEFINE_string(
name="bigquery_run_status_table", short_name="brst",
default="benchmark_run_status",
help=help_wrap("The Bigquery table name where the benchmark run "
"status information will be uploaded."))
flags.DEFINE_string(
name="bigquery_metric_table", short_name="bmt",
default="benchmark_metric",
help=help_wrap("The Bigquery table name where the benchmark metric "
"information will be uploaded."))
@flags.multi_flags_validator(
["benchmark_logger_type", "benchmark_log_dir"],
message="--benchmark_logger_type=BenchmarkFileLogger will require "
"--benchmark_log_dir being set")
def _check_benchmark_log_dir(flags_dict):
benchmark_logger_type = flags_dict["benchmark_logger_type"]
if benchmark_logger_type == "BenchmarkFileLogger":
return flags_dict["benchmark_log_dir"]
return True
return key_flags
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for shared arparse convention definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import codecs
import functools
from absl import app as absl_app
from absl import flags
# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
_help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
firstline_indent="\n")
# Pretty formatting causes issues when utf-8 is not installed on a system.
try:
codecs.lookup("utf-8")
help_wrap = _help_wrap
except LookupError:
def help_wrap(text, *args, **kwargs):
return _help_wrap(text, *args, **kwargs).replace("\ufeff", "")
# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for managing compute devices. Currently only contains TPU flags."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def require_cloud_storage(flag_names):
"""Register a validator to check directory flags.
Args:
flag_names: An iterable of strings containing the names of flags to be
checked.
"""
msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
@flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
def _path_check(flag_values): # pylint: disable=missing-docstring
if flag_values["tpu"] is None:
return True
valid_flags = True
for key in flag_names:
if not flag_values[key].startswith("gs://"):
tf.logging.error("{} must be a GCS path.".format(key))
valid_flags = False
return valid_flags
def define_device(tpu=True):
"""Register device specific flags.
Args:
tpu: Create flags to specify TPU operation.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if tpu:
flags.DEFINE_string(
name="tpu", default=None,
help=help_wrap(
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a "
"grpc://ip.address.of.tpu:8470 url. Passing `local` will use the"
"CPU of the local instance instead. (Good for debugging.)"))
key_flags.append("tpu")
flags.DEFINE_string(
name="tpu_zone", default=None,
help=help_wrap(
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_string(
name="tpu_gcp_project", default=None,
help=help_wrap(
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_integer(name="num_tpu_shards", default=8,
help=help_wrap("Number of shards (TPU chips)."))
return key_flags
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags._conventions import help_wrap
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def get_tf_dtype(flags_obj):
return DTYPE_MAP[flags_obj.dtype][0]
def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True,
all_reduce_alg=True, tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
datasets_num_parallel_batches=False):
"""Register flags for specifying performance tuning arguments.
Args:
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
synthetic_data: Create a flag to allow the use of synthetic data.
max_train_steps: Create a flag to allow specification of the maximum number
of training steps.
dtype: Create flags for specifying dtype.
all_reduce_alg: Create a flag to specify the algorithm used for multi-GPU all-reduce.
tf_gpu_thread_mode: Create a flag to specify the GPU thread mode; gpu_private triggers use of a private thread pool.
datasets_num_private_threads: Number of private threads for datasets.
datasets_num_parallel_batches: Determines how many batches to process in
parallel when using map and batch from tf.data.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if num_parallel_calls:
flags.DEFINE_integer(
name="num_parallel_calls", short_name="npc",
default=multiprocessing.cpu_count(),
help=help_wrap("The number of records that are processed in parallel "
"during input processing. This can be optimized per "
"data set but for generally homogeneous data sets, "
"should be approximately the number of available CPU "
"cores. (default behavior)"))
if inter_op:
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details.")
)
if intra_op:
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details."))
if synthetic_data:
flags.DEFINE_bool(
name="use_synthetic_data", short_name="synth", default=False,
help=help_wrap(
"If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."))
if max_train_steps:
flags.DEFINE_integer(
name="max_train_steps", short_name="mts", default=None, help=help_wrap(
"The model will stop training if the global_step reaches this "
"value. If not set, training will run until the specified number "
"of epochs have run as usual. It is generally recommended to set "
"--train_epochs=1 when using this flag."
))
if dtype:
flags.DEFINE_enum(
name="dtype", short_name="dt", default="fp32",
enum_values=DTYPE_MAP.keys(),
help=help_wrap("The TensorFlow datatype used for calculations. "
"Variables may be cast to a higher precision on a "
"case-by-case basis for numerical stability."))
flags.DEFINE_integer(
name="loss_scale", short_name="ls", default=None,
help=help_wrap(
"The amount to scale the loss by when the model is run. Before "
"gradients are computed, the loss is multiplied by the loss scale, "
"making all gradients loss_scale times larger. To adjust for this, "
"gradients are divided by the loss scale before being applied to "
"variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes."))
loss_scale_val_msg = "loss_scale should be a positive integer."
@flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
if loss_scale is None:
return True # null case is handled in get_loss_scale()
return loss_scale > 0
if all_reduce_alg:
flags.DEFINE_string(
name="all_reduce_alg", short_name="ara", default=None,
help=help_wrap("Defines the algorithm to use for performing all-reduce."
"See tf.contrib.distribute.AllReduceCrossTowerOps for "
"more details and available options."))
if tf_gpu_thread_mode:
flags.DEFINE_string(
name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
help=help_wrap(
"Whether and how the GPU device uses its own threadpool.")
)
if datasets_num_private_threads:
flags.DEFINE_integer(
name="datasets_num_private_threads",
default=None,
help=help_wrap(
"Number of threads for a private threadpool created for all"
"datasets computation..")
)
if datasets_num_parallel_batches:
flags.DEFINE_integer(
name="datasets_num_parallel_batches",
default=None,
help=help_wrap(
"Determines how many batches to process in parallel when using "
"map and batch from tf.data.")
)
return key_flags
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public interface for flag definition.
See _example.py for detailed instructions on defining flags.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
from absl import app as absl_app
from absl import flags
from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _device
from official.utils.flags import _misc
from official.utils.flags import _performance
def set_defaults(**kwargs):
for key, value in kwargs.items():
flags.FLAGS.set_default(name=key, value=value)
def parse_flags(argv=None):
"""Reset flags and reparse. Currently only used in testing."""
flags.FLAGS.unparse_flags()
absl_app.parse_flags_with_usage(argv or sys.argv)
def register_key_flags_in_core(f):
"""Defines a function in core.py, and registers its key flags.
absl uses the location of a flags.declare_key_flag() to determine the context
in which a flag is key. By making all declares in core, this allows model
main functions to call flags.adopt_module_key_flags() on core and correctly
chain key flags.
Args:
f: The function to be wrapped
Returns:
The "core-defined" version of the input function.
"""
def core_fn(*args, **kwargs):
key_flags = f(*args, **kwargs)
[flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned
return core_fn
define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, epochs_between_evals=False, stop_threshold=False,
hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
DTYPE_MAP = _performance.DTYPE_MAP
require_cloud_storage = _device.require_cloud_storage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core # pylint: disable=g-bad-import-order
def define_flags():
flags_core.define_base(num_gpu=False)
flags_core.define_performance()
flags_core.define_image()
flags_core.define_benchmark()
class BaseTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(BaseTester, cls).setUpClass()
define_flags()
def test_default_setting(self):
"""Test to ensure fields exist and defaults can be set.
"""
defaults = dict(
data_dir="dfgasf",
model_dir="dfsdkjgbs",
train_epochs=534,
epochs_between_evals=15,
batch_size=256,
hooks=["LoggingTensorHook"],
num_parallel_calls=18,
inter_op_parallelism_threads=5,
intra_op_parallelism_threads=10,
data_format="channels_first"
)
flags_core.set_defaults(**defaults)
flags_core.parse_flags()
for key, value in defaults.items():
assert flags.FLAGS.get_flag_value(name=key, default=None) == value
def test_benchmark_setting(self):
defaults = dict(
hooks=["LoggingMetricHook"],
benchmark_log_dir="/tmp/12345",
gcp_project="project_abc",
)
flags_core.set_defaults(**defaults)
flags_core.parse_flags()
for key, value in defaults.items():
assert flags.FLAGS.get_flag_value(name=key, default=None) == value
def test_booleans(self):
"""Test to ensure boolean flags trigger as expected.
"""
flags_core.parse_flags([__file__, "--use_synthetic_data"])
assert flags.FLAGS.use_synthetic_data
def test_parse_dtype_info(self):
for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
["fp32", tf.float32, 1]]:
flags_core.parse_flags([__file__, "--dtype", dtype_str])
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)
flags_core.parse_flags(
[__file__, "--dtype", dtype_str, "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)
with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "int8"])
if __name__ == "__main__":
unittest.main()
# Using flags in official models
1. **All common flags must be incorporated in the models.**
Common flags (e.g. batch_size, model_dir, etc.) are provided by various flag definition functions,
and channeled through `official.utils.flags.core`. For instance, to define common supervised
learning parameters, one could use the following code:
```python
from absl import app as absl_app
from absl import flags
from official.utils.flags import core as flags_core
def define_flags():
flags_core.define_base()
flags.adopt_module_key_flags(flags_core)
def main(_):
flags_obj = flags.FLAGS
print(flags_obj)
if __name__ == "__main__":
absl_app.run(main)
```
2. **Validate flag values.**
See the [Validators](#validators) section for implementation details.
Validators in the official model repo should not access the file system (for example, to verify
that files exist), due to strict ordering requirements. A minimal sketch of a validator that checks
a value without touching the file system follows.
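This sketch assumes the `data_dir` flag from `flags_core.define_base()`; the exact constraint is illustrative only:
```python
from absl import flags

@flags.validator("data_dir", message="data_dir must be an absolute or GCS path")
def _check_data_dir(data_dir):
  # Inspect the string itself rather than probing the file system.
  return data_dir.startswith("/") or data_dir.startswith("gs://")
```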
3. **Flag values should not be mutated.**
Instead of mutating flag values, use getter functions to return the desired values. An example is
the `get_loss_scale` function below:
```
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
def main(_):
flags_obj = flags.FLAGS
# Do not mutate flags_obj
# if flags_obj.loss_scale is None:
# flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1] # Don't do this
print(get_loss_scale(flags_obj))
...
```
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that interact with cloud service.
"""
import requests
GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
def on_gcp():
"""Detect whether the current running environment is on GCP."""
try:
# Time out after 5 seconds, in case the environment has connectivity issues.
# There is no default timeout, which means the request might block forever.
response = requests.get(
GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
# Logging in official models
This library adds logging functions that print or save tensor values. Official models should define all common hooks
(using hooks helper) and a benchmark logger.
1. **Training Hooks**
Hooks are a TensorFlow concept that define specific actions at certain points of the execution. We use them to obtain and log
tensor values during training.
hooks_helper.py provides an easy way to create common hooks; a usage sketch follows the list. The following hooks are currently defined:
* LoggingTensorHook: Logs tensor values
* ProfilerHook: Writes a timeline json that can be loaded into chrome://tracing.
* ExamplesPerSecondHook: Logs the number of examples processed per second.
* LoggingMetricHook: Similar to LoggingTensorHook, except that the tensors are logged in a format defined by our data
analysis pipeline.
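As a rough sketch, a model can request these hooks by name through hooks_helper; the exact keyword arguments depend on which hooks are requested, and the batch_size below is a placeholder value:
```python
from official.utils.logs import hooks_helper

# Hook names are matched case-insensitively; batch_size is forwarded to
# ExamplesPerSecondHook so it can convert step time into examples/second.
train_hooks = hooks_helper.get_train_hooks(
    ["LoggingTensorHook", "ExamplesPerSecondHook"],
    batch_size=128)
```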
2. **Benchmarks**
The benchmark logger provides useful functions for logging environment information and evaluation results.
The module also contains a context which is used to update the status of the run.
Example usage:
```
from absl import app as absl_app
from official.utils.logs import hooks_helper
from official.utils.logs import logger
def model_main(flags_obj):
estimator = ...
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(...)
train_hooks = hooks_helper.get_train_hooks(...)
for epoch in range(10):
estimator.train(..., hooks=train_hooks)
eval_results = estimator.evaluate(...)
# Log a dictionary of metrics
benchmark_logger.log_evaluation_result(eval_results)
# Log an individual metric
benchmark_logger.log_metric(...)
def main(_):
with logger.benchmark_context(flags.FLAGS):
model_main(flags.FLAGS)
if __name__ == "__main__":
# define flags
absl_app.run(main)
```
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hook that counts examples per second every N steps or seconds."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import logger
class ExamplesPerSecondHook(tf.compat.v1.train.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
most recent interval is also logged.
"""
def __init__(self,
batch_size,
every_n_steps=None,
every_n_secs=None,
warm_steps=0,
metric_logger=None):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of the
`every_n_steps` or `every_n_secs` should be set.
warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all
workers, not on each worker
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
both are set.
"""
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError("exactly one of every_n_steps"
" and every_n_secs should be provided.")
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.train.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0
self._total_steps = 0
self._batch_size = batch_size
self._warm_steps = warm_steps
def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object or None if never triggered.
"""
return tf.train.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step = run_values.results
if self._timer.should_trigger_for_step(
global_step) and global_step > self._warm_steps:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
self._step_train_time += elapsed_time
self._total_steps += elapsed_steps
# average examples per second is based on the total (accumulative)
# training steps and training time so far
average_examples_per_sec = self._batch_size * (
self._total_steps / self._step_train_time)
# current examples per second is based on the elapsed training steps
# and training time per batch
current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time)
self._logger.log_metric(
"average_examples_per_sec", average_examples_per_sec,
global_step=global_step)
self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=global_step)