Commit 555722af authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323668407
parent ab2c8ba9
......@@ -96,6 +96,30 @@ def create_tf_while_loop_fn(step_fn):
return loop_fn
def create_global_step() -> tf.Variable:
"""Creates a `tf.Variable` suitable for use as a global step counter.
Creating and managing a global step variable may be necessary for
`AbstractTrainer` subclasses that perform multiple parameter updates per
`Controller` "step", or use different optimizers on different steps.
In these cases, an `optimizer.iterations` property generally can't be used
directly, since it would correspond to parameter updates instead of iterations
in the `Controller`'s training loop. Such use cases should simply call
`step.assign_add(1)` at the end of each step.
Returns:
A non-trainable scalar `tf.Variable` of dtype `tf.int64`, with only the
first replica's value retained when synchronizing across replicas in
a distributed setting.
"""
return tf.Variable(
0,
dtype=tf.int64,
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
"""A helper function to create distributed dataset.
......@@ -186,13 +210,13 @@ class SummaryManager:
def write_summaries(self, summary_dict):
"""Write summaries for the given values.
This recursively creates sub-directories for any nested dictionaries
This recursively creates subdirectories for any nested dictionaries
provided in `summary_dict`, yielding a hierarchy of directories which will
then be reflected in the TensorBoard UI as different colored curves.
E.g. users may evaluate on muliple datasets and return `summary_dict` as a
nested
dictionary.
nested dictionary.
```
{
"dataset": {
......@@ -205,9 +229,10 @@ class SummaryManager:
},
}
```
It will create two sub directories "dataset" and "dataset2" inside summary
root directory. And each directory write both "loss" and "accuracy"
summaries inside.
This will create two subdirectories "dataset" and "dataset2" inside the
summary root directory. Each directory will contain event files including
both "loss" and "accuracy" summaries.
Args:
summary_dict: A dictionary of values. If any value in `summary_dict` is
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.utils."""
from orbit import utils
import tensorflow as tf
class UtilsTest(tf.test.TestCase):
def test_create_global_step(self):
step = utils.create_global_step()
self.assertEqual(step.dtype, tf.int64)
self.assertEqual(step, 0)
step.assign_add(1)
self.assertEqual(step, 1)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment