"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "6373534ad4551a96cc45ce447c609ce62f2f695e"
Commit 79354e14 authored by Dan Holtmann-Rice, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334528962
parent 8f345563
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Orbit package definition."""
+"""Defines exported symbols for `orbit` package."""
 from orbit import utils
 from orbit.controller import Controller
-from orbit.runner import *
-from orbit.standard_runner import *
+from orbit.runner import AbstractEvaluator
+from orbit.runner import AbstractTrainer
+from orbit.standard_runner import StandardEvaluator
+from orbit.standard_runner import StandardEvaluatorOptions
+from orbit.standard_runner import StandardTrainer
+from orbit.standard_runner import StandardTrainerOptions
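With the wildcard imports replaced by an explicit export list, downstream code can rely on a stable public surface. A quick sketch of the imports this enables (symbols taken from the list above):

```python
from orbit import Controller
from orbit import AbstractEvaluator, AbstractTrainer
from orbit import StandardEvaluator, StandardEvaluatorOptions
from orbit import StandardTrainer, StandardTrainerOptions
```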
@@ -15,9 +15,12 @@
 """Lightweight utilities for training TF2 models."""
 import time
 from typing import Callable, Dict, Optional, Text, Union
 from absl import logging
 import numpy as np
 from orbit import runner
 from orbit import utils
...
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""An abstraction that users can easily handle their custom training loops."""
+"""AbstractTrainer/Evaluator implementations for standard settings."""
 import abc
 from typing import Any, Dict, Optional, Text
 import dataclasses
 from orbit import runner
-from orbit import utils
+from orbit.utils import loop_fns
 import tensorflow as tf
@@ -45,6 +49,21 @@ class StandardTrainerOptions:
   use_tpu_summary_optimization: bool = False
+def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
+  """Creates a training loop from the given step function and options."""
+  if options.use_tf_while_loop:
+    loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
+    if options.use_tpu_summary_optimization:
+      loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
+    else:
+      loop_fn = tf.function(loop_fn)
+  else:
+    if options.use_tf_function:
+      train_step_fn = tf.function(train_step_fn)
+    loop_fn = loop_fns.create_loop_fn(train_step_fn)
+  return loop_fn
 class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractTrainer APIs."""
@@ -64,36 +83,25 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
       raise ValueError("`use_tpu_summary_optimization=True` and "
                        "`use_tf_while_loop=False` is not supported")
-    self._use_tf_while_loop = options.use_tf_while_loop
-    self._use_tf_function = options.use_tf_function
-    self._use_tpu_summary_optimization = options.use_tpu_summary_optimization
+    self._train_options = options
     self._train_dataset = train_dataset
     self._train_iter = None
     self._train_loop_fn = None
-  def train(self,
-            num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+  def train(
+      self,
+      num_steps: Optional[tf.Tensor],
+  ) -> Optional[Dict[Text, tf.Tensor]]:
     """See base class."""
     self.train_loop_begin()
+    if self._train_loop_fn is None:
+      self._train_loop_fn = _create_train_loop_fn(
+          self.train_step, options=self._train_options)
     if self._train_iter is None:
       self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
-    if self._train_loop_fn is None:
-      train_fn = self.train_step
-      if self._use_tf_while_loop:
-        self._train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
-        if self._use_tpu_summary_optimization:
-          self._train_loop_fn = utils.train_function_with_summaries(
-              self._train_loop_fn)
-        else:
-          self._train_loop_fn = tf.function(self._train_loop_fn)
-      else:
-        if self._use_tf_function:
-          train_fn = tf.function(train_fn)
-        self._train_loop_fn = utils.create_loop_fn(train_fn)
     self._train_loop_fn(self._train_iter, num_steps)
     return self.train_loop_end()
@@ -172,6 +180,12 @@ class StandardEvaluatorOptions:
   use_tf_function: bool = True
+def _create_eval_loop_fn(eval_step_fn, options: StandardEvaluatorOptions):
+  if options.use_tf_function:
+    eval_step_fn = tf.function(eval_step_fn)
+  return loop_fns.create_loop_fn(eval_step_fn)
 class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractEvaluator APIs."""
@@ -183,25 +197,25 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
         DistributedDataset.
       options: An `orbit.StandardEvaluatorOptions` instance.
     """
-    options = options or StandardEvaluatorOptions()
-    self._eval_use_tf_function = options.use_tf_function
+    self._eval_options = options or StandardEvaluatorOptions()
     self._eval_dataset = eval_dataset
     self._eval_loop_fn = None
   def evaluate(
-      self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+      self,
+      num_steps: Optional[tf.Tensor],
+  ) -> Optional[Dict[Text, tf.Tensor]]:
     """See base class."""
     outputs = self.eval_begin()  # pylint: disable=assignment-from-no-return
+    eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
     if self._eval_loop_fn is None:
-      eval_fn = self.eval_step
-      if self._eval_use_tf_function:
-        eval_fn = tf.function(eval_fn)
-      self._eval_loop_fn = utils.create_loop_fn(eval_fn)
-    eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
+      self._eval_loop_fn = _create_eval_loop_fn(
+          self.eval_step, options=self._eval_options)
     outputs = self._eval_loop_fn(
         eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
     if outputs is None:
       return self.eval_end()
     else:
...
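To see the refactored trainer in use, here is a minimal hedged sketch of a concrete subclass. It assumes the `StandardTrainer` constructor accepts `(train_dataset, options=None)` as implied by the diff, plus the `train_loop_begin`/`train_step`/`train_loop_end` hooks it drives; the model, optimizer, and loss below are hypothetical:

```python
import tensorflow as tf
import orbit


class MyTrainer(orbit.StandardTrainer):
  """Hypothetical trainer; only the three hook methods are required."""

  def __init__(self, train_dataset, model, optimizer):
    self.model = model
    self.optimizer = optimizer
    self.train_loss = tf.keras.metrics.Mean("loss")
    super().__init__(train_dataset=train_dataset)

  def train_loop_begin(self):
    self.train_loss.reset_states()

  def train_step(self, iterator):
    x, y = next(iterator)
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(
          tf.keras.losses.mean_squared_error(y, self.model(x, training=True)))
    grads = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    self.train_loss.update_state(loss)

  def train_loop_end(self):
    return {"loss": self.train_loss.result()}
```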
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
import abc
import contextlib
import functools
import inspect
import os
import numpy as np
import tensorflow as tf
def create_loop_fn(step_fn):
"""Creates a multiple steps function driven by the python while loop.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable defined as the `loop_fn` defination below.
"""
def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. If `num_steps==-1`, will
iterate until exausting the iterator.
state: An optional initial state before running the loop.
reduce_fn: a callable defined as `def reduce_fn(state, value)`, where
`value` is the outputs from `step_fn`.
Returns:
The updated state.
"""
try:
step = 0
# To make sure the OutOfRangeError exception can be handled well with
# async remote eager, we need to wrap the loop body in a `async_scope`.
with tf.experimental.async_scope():
while (num_steps == -1 or step < num_steps):
outputs = step_fn(iterator)
if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return state
return loop_fn
def create_tf_while_loop_fn(step_fn):
  """Creates a multiple steps function driven by tf.while_loop on the host.

  Args:
    step_fn: A function which takes `iterator` as input.

  Returns:
    A callable defined as the `loop_fn` definition below.
  """

  def loop_fn(iterator, num_steps):
    """A loop function with multiple steps.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Must be a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError("`num_steps` should be a `tf.Tensor`. Python objects "
                       "may cause retracing.")
    for _ in tf.range(num_steps):
      step_fn(iterator)

  return loop_fn
def create_global_step() -> tf.Variable:
  """Creates a `tf.Variable` suitable for use as a global step counter.

  Creating and managing a global step variable may be necessary for
  `AbstractTrainer` subclasses that perform multiple parameter updates per
  `Controller` "step", or use different optimizers on different steps.

  In these cases, an `optimizer.iterations` property generally can't be used
  directly, since it would correspond to parameter updates instead of
  iterations in the `Controller`'s training loop. Such use cases should simply
  call `step.assign_add(1)` at the end of each step.

  Returns:
    A non-trainable scalar `tf.Variable` of dtype `tf.int64`, with only the
    first replica's value retained when synchronizing across replicas in
    a distributed setting.
  """
  return tf.Variable(
      0,
      dtype=tf.int64,
      trainable=False,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A helper function to create a distributed dataset.

  Args:
    strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: An instance of `tf.data.Dataset` or a function which takes
      a `tf.distribute.InputContext` as input and returns a `tf.data.Dataset`.
      If it is a function, it could optionally have an argument named
      `input_context` of type `tf.distribute.InputContext`.
    *args: The list of arguments to be passed to `dataset_or_fn`.
    **kwargs: Any keyword arguments to be passed.

  Returns:
    A distributed Dataset.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`")

  def dataset_fn(ctx):
    """Wrapped dataset function for creating a distributed dataset."""
    # If `dataset_or_fn` is a function and has an `input_context` argument,
    # pass `ctx` as the value of `input_context` when calling `dataset_or_fn`.
    # Otherwise `ctx` will not be used when calling `dataset_or_fn`.
    argspec = inspect.getfullargspec(dataset_or_fn)
    args_names = argspec.args
    if "input_context" in args_names:
      kwargs["input_context"] = ctx
    ds = dataset_or_fn(*args, **kwargs)
    return ds

  return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager:
  """A class that manages writing summaries."""

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Constructs a summary manager object.

    Args:
      summary_dir: The directory to write summaries to.
      summary_fn: A callable defined as `def summary_fn(name, tensor,
        step=None)`, which describes the summary operation.
      global_step: A `tf.Variable` instance for the global step.
    """
    self._enabled = summary_dir is not None
    self._summary_dir = summary_dir
    self._summary_fn = summary_fn
    self._summary_writers = {}

    if global_step is None:
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step

  def summary_writer(self, relative_path=""):
    """Returns the underlying summary writer.

    Args:
      relative_path: The current path in which to write summaries, relative to
        the summary directory. By default it is empty, which specifies the
        root directory.
    """
    if self._summary_writers and relative_path in self._summary_writers:
      return self._summary_writers[relative_path]
    if self._enabled:
      self._summary_writers[relative_path] = tf.summary.create_file_writer(
          os.path.join(self._summary_dir, relative_path))
    else:
      self._summary_writers[relative_path] = tf.summary.create_noop_writer()
    return self._summary_writers[relative_path]

  def flush(self):
    """Flushes the underlying summary writers."""
    if self._enabled:
      tf.nest.map_structure(tf.summary.flush, self._summary_writers)

  def write_summaries(self, summary_dict):
    """Writes summaries for the given values.

    This recursively creates subdirectories for any nested dictionaries
    provided in `summary_dict`, yielding a hierarchy of directories which will
    then be reflected in the TensorBoard UI as different colored curves. For
    example, users may evaluate on multiple datasets and return `summary_dict`
    as a nested dictionary:

    ```
    {
        "dataset": {
            "loss": loss,
            "accuracy": accuracy
        },
        "dataset2": {
            "loss": loss2,
            "accuracy": accuracy2
        },
    }
    ```

    This will create two subdirectories, "dataset" and "dataset2", inside the
    summary root directory. Each directory will contain event files including
    both "loss" and "accuracy" summaries.

    Args:
      summary_dict: A dictionary of values. If any value in `summary_dict` is
        itself a dictionary, then the function will recursively create
        subdirectories with names given by the keys in the dictionary. The
        Tensor values are summarized using the summary writer instance
        specific to the parent relative path.
    """
    if not self._enabled:
      return
    self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    for name, value in summary_dict.items():
      if isinstance(value, dict):
        self._write_summaries(
            value, relative_path=os.path.join(relative_path, name))
      else:
        with self.summary_writer(relative_path).as_default():
          self._summary_fn(name, value, step=self._global_step)
class Trigger(metaclass=abc.ABCMeta):
  """An abstract class representing a "trigger" for some event."""

  @abc.abstractmethod
  def __call__(self, value: float, force_trigger=False):
    """Maybe triggers the event based on the given value.

    Args:
      value: The value for triggering.
      force_trigger: Whether to force the trigger to fire.

    Returns:
      `True` if the trigger fired on the given `value`, and `False` otherwise.
    """

  @abc.abstractmethod
  def reset(self):
    """Resets states in the trigger."""
class IntervalTrigger(Trigger):
  """Triggers on every fixed interval."""

  def __init__(self, interval, start=0):
    """Constructs the IntervalTrigger.

    Args:
      interval: The triggering interval.
      start: An initial value for the trigger.
    """
    self._interval = interval
    self._last_trigger_value = start

  def __call__(self, value, force_trigger=False):
    """Maybe triggers the event based on the given value.

    Args:
      value: The value for triggering.
      force_trigger: If True, the trigger fires unless the last trigger value
        is equal to `value`.

    Returns:
      `True` if the trigger fired on the given `value`, and `False` otherwise.
    """
    if force_trigger and value != self._last_trigger_value:
      self._last_trigger_value = value
      return True

    if self._interval and self._interval > 0:
      if value >= self._last_trigger_value + self._interval:
        self._last_trigger_value = value
        return True
    return False

  def reset(self):
    """See base class."""
    self._last_trigger_value = 0
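The trigger semantics above can be summarized with a few concrete values (illustrative, derived from the code just shown):

```python
trigger = IntervalTrigger(interval=100)
assert not trigger(50)    # 50 < 0 + 100: below the first interval boundary
assert trigger(100)       # 100 >= 0 + 100: fires; last value becomes 100
assert not trigger(150)   # 150 < 100 + 100: not yet
assert trigger(150, force_trigger=True)  # forced, since 150 != 100
```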
class EpochHelper:
  """A helper class to handle epochs in customized training loops."""

  def __init__(self, epoch_steps, global_step):
    """Constructs the EpochHelper.

    Args:
      epoch_steps: An integer indicating how many steps are in an epoch.
      global_step: A `tf.Variable` instance indicating the current global
        step.
    """
    self._epoch_steps = epoch_steps
    self._global_step = global_step
    self._current_epoch = None
    self._epoch_start_step = None
    self._in_epoch = False

  def epoch_begin(self):
    """Returns whether a new epoch should begin."""
    if self._in_epoch:
      return False
    current_step = self._global_step.numpy()
    self._epoch_start_step = current_step
    self._current_epoch = current_step // self._epoch_steps
    self._in_epoch = True
    return True

  def epoch_end(self):
    """Returns whether the current epoch should end."""
    if not self._in_epoch:
      raise ValueError("`epoch_end` can only be called inside an epoch.")
    current_step = self._global_step.numpy()
    epoch = current_step // self._epoch_steps
    if epoch > self._current_epoch:
      self._in_epoch = False
      return True
    return False

  @property
  def batch_index(self):
    """Index of the next batch within the current epoch."""
    return self._global_step.numpy() - self._epoch_start_step

  @property
  def current_epoch(self):
    return self._current_epoch
@contextlib.contextmanager
def _soft_device_placement():
  """Context manager for soft device placement, allowing summaries on CPU."""
  original_setting = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    tf.config.set_soft_device_placement(original_setting)
def train_function_with_summaries(*args, **kwargs):
  """Utility function to support TPU summaries via multiple `tf.function`s.

  This permits interleaving summaries inside TPU-compatible code, but without
  any performance impact on steps that do not write summaries. Usage is as a
  decorator, similar to `tf.function`, and any `tf.function` arguments will be
  passed through if supplied:

      @trainer.train_function_with_summaries
      def train(self, num_steps):
        ...

  The decorated function is assumed to be a loop method accepting a
  `num_steps` parameter, as for instance would be called within the
  `Controller`'s outer train loop. The implementation here assumes that
  `summary_frequency` is divisible by `steps_per_loop`. The decorated method
  should accept two arguments, `self` and `num_steps`.

  Two `tf.function` versions of `train_fn` are created: one inside a summary
  writer scope with soft device placement enabled (used on steps that require
  summary writing), and one with no summary writer present and soft device
  placement disabled (used on all other steps).

  Args:
    *args: Arguments to pass through to `tf.function`.
    **kwargs: Keyword arguments to pass through to `tf.function`.

  Returns:
    If the first argument is a callable, returns the decorated callable.
    Otherwise, returns a decorator.
  """

  def decorator(train_fn):
    # TODO(dhr): Validate the signature of train_fn?
    train_fn_with_summaries = tf.function(train_fn, *args, **kwargs)
    train_fn_without_summaries = tf.function(train_fn, *args, **kwargs)

    @functools.wraps(train_fn)
    def wrapper(self, num_steps):
      if tf.summary.should_record_summaries():
        with _soft_device_placement():
          output = train_fn_with_summaries(self, tf.constant(1))
          num_steps -= 1
      if num_steps >= 1:
        with tf.summary.record_if(False):
          output = train_fn_without_summaries(self, num_steps)
      return output

    return wrapper

  if args and callable(args[0]):
    train_fn, args = args[0], args[1:]
    return decorator(train_fn)
  return decorator
def get_value(x) -> np.number:
  """Returns the value of a variable/tensor.

  Args:
    x: The input variable.

  Returns:
    A Numpy array or number.
  """
  if not tf.is_tensor(x):
    return x
  return x.numpy()
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines exported symbols for `orbit.utils` package."""
from orbit.utils.common import create_global_step
from orbit.utils.common import get_value
from orbit.utils.common import make_distributed_dataset
from orbit.utils.epoch_helper import EpochHelper
from orbit.utils.loop_fns import create_loop_fn
from orbit.utils.loop_fns import create_tf_while_loop_fn
from orbit.utils.summary_manager import SummaryManager
from orbit.utils.tpu_summaries import OptionalSummariesFunction
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
import inspect
import numpy as np
import tensorflow as tf
def create_global_step() -> tf.Variable:
"""Creates a `tf.Variable` suitable for use as a global step counter.
Creating and managing a global step variable may be necessary for
`AbstractTrainer` subclasses that perform multiple parameter updates per
`Controller` "step", or use different optimizers on different steps.
In these cases, an `optimizer.iterations` property generally can't be used
directly, since it would correspond to parameter updates instead of iterations
in the `Controller`'s training loop. Such use cases should simply call
`step.assign_add(1)` at the end of each step.
Returns:
A non-trainable scalar `tf.Variable` of dtype `tf.int64`, with only the
first replica's value retained when synchronizing across replicas in
a distributed setting.
"""
return tf.Variable(
0,
dtype=tf.int64,
name="global_step",
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
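A hedged usage sketch for `create_global_step`; the `set_step` call is optional and only illustrates pairing the variable with TensorFlow's summary system:

```python
import tensorflow as tf
from orbit.utils.common import create_global_step

step = create_global_step()              # int64, non-trainable, named "global_step"
tf.summary.experimental.set_step(step)   # optional: default step for summaries
step.assign_add(1)                       # call once per Controller "step"
```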
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A helper function to create a distributed dataset.

  Args:
    strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: An instance of `tf.data.Dataset` or a function which takes
      a `tf.distribute.InputContext` as input and returns a `tf.data.Dataset`.
      If it is a function, it could optionally have an argument named
      `input_context` of type `tf.distribute.InputContext`.
    *args: The list of arguments to be passed to `dataset_or_fn`.
    **kwargs: Any keyword arguments to be passed.

  Returns:
    A distributed Dataset.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`")

  def dataset_fn(ctx):
    """Wrapped dataset function for creating a distributed dataset."""
    # If `dataset_or_fn` is a function and has an `input_context` argument,
    # pass `ctx` as the value of `input_context` when calling `dataset_or_fn`.
    # Otherwise `ctx` will not be used when calling `dataset_or_fn`.
    argspec = inspect.getfullargspec(dataset_or_fn)
    args_names = argspec.args
    if "input_context" in args_names:
      kwargs["input_context"] = ctx
    ds = dataset_or_fn(*args, **kwargs)
    return ds

  return strategy.experimental_distribute_datasets_from_function(dataset_fn)
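A hedged sketch of passing a dataset function so each input pipeline can shard its data; the pipeline details below are hypothetical:

```python
import tensorflow as tf
from orbit.utils.common import make_distributed_dataset


def dataset_fn(input_context: tf.distribute.InputContext):
  # Because the argument is named `input_context`, the helper passes the
  # per-pipeline context through automatically.
  batch_size = input_context.get_per_replica_batch_size(64)
  ds = tf.data.Dataset.range(1024).shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  return ds.batch(batch_size)


strategy = tf.distribute.MirroredStrategy()
dist_ds = make_distributed_dataset(strategy, dataset_fn)
```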
def get_value(x) -> np.number:
  """Returns the value of a variable/tensor.

  Args:
    x: The input variable.

  Returns:
    A Numpy array or number.
  """
  if not tf.is_tensor(x):
    return x
  return x.numpy()
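For illustration, `get_value` on tensor and non-tensor inputs:

```python
import tensorflow as tf
from orbit.utils.common import get_value

get_value(tf.constant(3))  # -> 3, as a numpy scalar (via .numpy())
get_value(7)               # -> 7, non-tensors pass through unchanged
```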
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for orbit.utils."""
+"""Tests for orbit.utils.common."""
-from orbit import utils
+from orbit.utils import common
 import tensorflow as tf
@@ -22,12 +22,13 @@ import tensorflow as tf
 class UtilsTest(tf.test.TestCase):
   def test_create_global_step(self):
-    step = utils.create_global_step()
+    step = common.create_global_step()
+    self.assertEqual(step.name, "global_step:0")
     self.assertEqual(step.dtype, tf.int64)
     self.assertEqual(step, 0)
     step.assign_add(1)
     self.assertEqual(step, 1)
-if __name__ == '__main__':
+if __name__ == "__main__":
   tf.test.main()
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides a utility class for training in epochs."""
import tensorflow as tf
class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps: int, global_step: tf.Variable):
"""Constructs the EpochHelper.
Args:
epoch_steps: An integer indicates how many steps in an epoch.
global_step: A `tf.Variable` instance indicates the current global step.
"""
self._epoch_steps = epoch_steps
self._global_step = global_step
self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False
def epoch_begin(self):
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
current_step = self._global_step.numpy()
self._epoch_start_step = current_step
self._current_epoch = current_step // self._epoch_steps
self._in_epoch = True
return True
def epoch_end(self):
"""Returns whether the current epoch should end."""
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch")
current_step = self._global_step.numpy()
epoch = current_step // self._epoch_steps
if epoch > self._current_epoch:
self._in_epoch = False
return True
return False
@property
def batch_index(self):
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step
@property
def current_epoch(self):
return self._current_epoch
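A hedged sketch of driving `EpochHelper` from a custom loop; `train_n_steps` is a hypothetical stand-in for a real training call that advances `global_step`:

```python
import tensorflow as tf
from orbit.utils.epoch_helper import EpochHelper

global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
helper = EpochHelper(epoch_steps=1000, global_step=global_step)
total_steps = 5000


def train_n_steps(n):  # hypothetical; a real trainer would run n train steps
  global_step.assign_add(n)


while global_step.numpy() < total_steps:
  if helper.epoch_begin():
    print(f"starting epoch {helper.current_epoch}")
  train_n_steps(100)
  if helper.epoch_end():
    print(f"finished epoch {helper.current_epoch}")
```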
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for creating loop functions."""
from orbit.utils import tpu_summaries
import tensorflow as tf
def create_loop_fn(step_fn):
"""Creates a multiple steps function driven by the python while loop.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable defined as the `loop_fn` defination below.
"""
def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. If `num_steps==-1`, will
iterate until exausting the iterator.
state: An optional initial state before running the loop.
reduce_fn: a callable defined as `def reduce_fn(state, value)`, where
`value` is the outputs from `step_fn`.
Returns:
The updated state.
"""
try:
step = 0
# To make sure the OutOfRangeError exception can be handled well with
# async remote eager, we need to wrap the loop body in a `async_scope`.
with tf.experimental.async_scope():
while (num_steps == -1 or step < num_steps):
outputs = step_fn(iterator)
if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return state
return loop_fn
def create_tf_while_loop_fn(step_fn):
"""Create a multiple steps function driven by tf.while_loop on the host.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable defined as the `loop_fn` defination below.
"""
def loop_fn(iterator, num_steps):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Must be a tf.Tensor.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError("`num_steps` should be an `tf.Tensor`. Python object "
"may cause retracing.")
for _ in tf.range(num_steps):
step_fn(iterator)
return loop_fn
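A hedged sketch contrasting the two loop builders; note that `create_tf_while_loop_fn` requires a `tf.Tensor` step count and is wrapped in `tf.function` here, as `_create_train_loop_fn` does in standard_runner.py:

```python
import tensorflow as tf
from orbit.utils import loop_fns


def step_fn(iterator):
  return next(iterator)

# Python-driven loop: supports state/reduce_fn, and num_steps == -1 runs
# until the iterator is exhausted.
loop = loop_fns.create_loop_fn(step_fn)
total = loop(
    iter(tf.data.Dataset.range(10)),
    num_steps=-1,
    state=tf.constant(0, tf.int64),
    reduce_fn=lambda state, outputs: state + outputs)  # total == 45

# tf.while_loop-driven loop: side effects only; num_steps must be a tensor.
tf_loop = tf.function(loop_fns.create_tf_while_loop_fn(step_fn))
tf_loop(iter(tf.data.Dataset.range(10)), tf.constant(5))
```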
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides a utility class for managing summary writing."""
import os
import tensorflow as tf
class SummaryManager:
"""A class manages writing summaries."""
def __init__(self, summary_dir, summary_fn, global_step=None):
"""Construct a summary manager object.
Args:
summary_dir: the directory to write summaries.
summary_fn: A callable defined as `def summary_fn(name, tensor,
step=None)`, which describes the summary operation.
global_step: A `tf.Variable` instance for the global step.
"""
self._enabled = (summary_dir is not None)
self._summary_dir = summary_dir
self._summary_fn = summary_fn
self._summary_writers = {}
if global_step is None:
self._global_step = tf.summary.experimental.get_step()
else:
self._global_step = global_step
def summary_writer(self, relative_path=""):
"""Returns the underlying summary writer.
Args:
relative_path: The current path in which to write summaries, relative to
the summary directory. By default it is empty, which specifies the root
directory.
"""
if self._summary_writers and relative_path in self._summary_writers:
return self._summary_writers[relative_path]
if self._enabled:
self._summary_writers[relative_path] = tf.summary.create_file_writer(
os.path.join(self._summary_dir, relative_path))
else:
self._summary_writers[relative_path] = tf.summary.create_noop_writer()
return self._summary_writers[relative_path]
def flush(self):
"""Flush the underlying summary writers."""
if self._enabled:
tf.nest.map_structure(tf.summary.flush, self._summary_writers)
def write_summaries(self, summary_dict):
"""Write summaries for the given values.
This recursively creates subdirectories for any nested dictionaries
provided in `summary_dict`, yielding a hierarchy of directories which will
then be reflected in the TensorBoard UI as different colored curves.
E.g. users may evaluate on muliple datasets and return `summary_dict` as a
nested dictionary.
```
{
"dataset": {
"loss": loss,
"accuracy": accuracy
},
"dataset2": {
"loss": loss2,
"accuracy": accuracy2
},
}
```
This will create two subdirectories "dataset" and "dataset2" inside the
summary root directory. Each directory will contain event files including
both "loss" and "accuracy" summaries.
Args:
summary_dict: A dictionary of values. If any value in `summary_dict` is
itself a dictionary, then the function will recursively create
subdirectories with names given by the keys in the dictionary. The
Tensor values are summarized using the summary writer instance specific
to the parent relative path.
"""
if not self._enabled:
return
self._write_summaries(summary_dict)
def _write_summaries(self, summary_dict, relative_path=""):
for name, value in summary_dict.items():
if isinstance(value, dict):
self._write_summaries(
value, relative_path=os.path.join(relative_path, name))
else:
with self.summary_writer(relative_path).as_default():
self._summary_fn(name, value, step=self._global_step)
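A hedged usage sketch for `SummaryManager` with a nested `summary_dict`; the directory path is illustrative:

```python
import tensorflow as tf
from orbit.utils.summary_manager import SummaryManager

step = tf.Variable(0, dtype=tf.int64)
manager = SummaryManager("/tmp/summaries", tf.summary.scalar, global_step=step)
manager.write_summaries({
    "train": {"loss": 0.25, "accuracy": 0.91},   # -> /tmp/summaries/train
    "validation": {"loss": 0.31},                # -> /tmp/summaries/validation
})
manager.flush()
```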
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains utilities for TPU summary optimization."""
import contextlib
import functools
import tensorflow as tf
@contextlib.contextmanager
def _soft_device_placement():
"""Context manager for soft device placement, allowing summaries on CPU."""
original_setting = tf.config.get_soft_device_placement()
try:
tf.config.set_soft_device_placement(True)
yield
finally:
tf.config.set_soft_device_placement(original_setting)
class OptionalSummariesFunction:
"""Wrapper that provides versions of a function with and without summaries.
This is a utility class for implementing optimized summary recording via a
two-function approach, specifically important for TPUs. Two `tf.function`
versions of a given `function` are created: one with soft device placement
enabled (for use on steps that require summary writing), and one with summary
writing and soft device placement entirely disabled (for use on all other
steps). This removes any performance impact of summaries on steps where they
aren't recorded (b/148418718).
This class can be used as a base class to implement summary optimizations for
a function with a specific signature. For example, to implement efficient TPU
summaries for a standard `train()` method (as in `orbit.AbstractTrainer`):
class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
'''Implements a two-program approach for summaries on TPU.'''
def __call__(self, num_steps):
if tf.summary.should_record_summaries():
output = self.with_summaries(tf.constant(1))
num_steps -= 1
if num_steps >= 1:
output = self.without_summaries(num_steps)
return output
This can be used directly or to implement a decorator:
def train_function_with_summaries(function=None, **kwargs):
if function is not None:
return TrainFunctionWithSummaries(function, **kwargs)
return functools.partial(TrainFunctionWithSummaries, **kwargs)
The director can be applied directly to `train()` methods:
@train_function_with_summaries
def train(self, num_steps):
...
A similar approach approach can be implemented for functions with different
signatures.
Note: The above approach assumes that the frequency of summary writing is
based on a step interval that is divisible by the number of steps executed
in each call to the `train()` function. This is enforced by the
`orbit.Controller`.
This wrapper properly handles instance methods (see `__get__`).
Attributes:
with_summaries: A wrapped version of the underlying function with summaries
enabled (using whatever the active predicate is for
`tf.summary.record_if`), and placed inside a "soft device placement"
context to enable summary recording on TPU.
without_summaries: A wrapped version of the underlying function with all
summary recording disabled.
"""
def __init__(self, function, **tf_function_kwargs):
"""Constructs an instance wrapping the given `function`.
The given `function` is wrapped twice: Once in a "soft device placement"
context (allowing summaries to also run on TPU), and once with summary
recording entirely disabled.
Both of these versions are compiled via `tf.function` (optionally using any
supplied `tf.function` settings), and made available as attributes.
Args:
function: The underlying function to wrap.
**tf_function_kwargs: Additional arguments to pass to `tf.function`.
"""
@tf.function(**tf_function_kwargs)
@functools.wraps(function)
def with_summaries(*args, **kwargs):
with _soft_device_placement():
return function(*args, **kwargs)
@tf.function(**tf_function_kwargs)
@functools.wraps(function)
def without_summaries(*args, **kwargs):
with tf.summary.record_if(False):
return function(*args, **kwargs)
self.with_summaries = with_summaries
self.without_summaries = without_summaries
def __get__(self, instance, owner):
"""Allows this class to be used to wrap methods as well as free functions.
For `tf.function` to work properly in all cases (e.g., when an
input_signature is specified), any `tf.function`-converted methods must be
properly bound to an instance if they are called as an instance method.
This is done by implementing this `__get__` method of the descriptor
protocol, and forwarding to the `__get__` method on the underlying
`tf.function`s.
Args:
instance: The instance to bind to.
owner: The class type of the instance.
Returns:
A new bound instance of `TpuDiscretionarySummariesFunctions`.
"""
new = object.__new__(self.__class__)
# pytype: disable=attribute-error # See b/162476201.
new.with_summaries = self.with_summaries.__get__(instance, owner)
new.without_summaries = self.without_summaries.__get__(instance, owner)
# pytype: enable=attribute-error
return new
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.utils.tpu_summaries."""
import functools
import os
from orbit.utils import common
from orbit.utils import tpu_summaries
import tensorflow as tf
class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction):
"""Implements a two-program approach for summaries on TPU."""
def __call__(self, num_steps):
if tf.summary.should_record_summaries():
output = self.with_summaries(tf.constant(1))
num_steps -= 1
if num_steps >= 1:
output = self.without_summaries(num_steps)
return output
def train_function_with_summaries(function=None, **kwargs):
if function is not None:
return TrainFunctionWithSummaries(function, **kwargs)
return functools.partial(TrainFunctionWithSummaries, **kwargs)
class DummyTrainer(tf.Module):
def __init__(self):
self.step_counter = common.create_global_step()
@train_function_with_summaries
def train_with_tpu_summary_optimization(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
@train_function_with_summaries(
input_signature=[tf.TensorSpec((), dtype=tf.int32)])
def train_with_tpu_summary_optimization_and_input_signature(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
def train_with_tpu_summary_optimization_no_decorator(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
class TpuSummariesTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.trainer = DummyTrainer()
def _get_events_from_logdir(self, logdir):
event_files = tf.io.gfile.listdir(logdir)
self.assertLen(event_files, 1)
path = os.path.join(logdir, event_files[0])
events = list(tf.compat.v1.train.summary_iterator(path))
return [event for event in events if event.WhichOneof("what") == "summary"]
def _validate_tpu_summary_optimization(self, function, *args, **kwargs):
logdir = self.get_temp_dir()
with tf.summary.create_file_writer(logdir).as_default():
with tf.summary.record_if(lambda: self.trainer.step_counter % 20 == 0):
for _ in range(4):
output = function(tf.constant(10), *args, **kwargs)
events = self._get_events_from_logdir(logdir)
self.assertLen(events, 2)
self.assertEqual(events[0].step, 0)
self.assertEqual(events[1].step, 20)
return output
def test_train_with_tpu_summary_optimization(self):
output = self._validate_tpu_summary_optimization(
self.trainer.train_with_tpu_summary_optimization)
self.assertEqual(output, self.trainer.step_counter.numpy())
def test_train_with_tpu_summary_optimization_no_decorator(self):
optimized = train_function_with_summaries(
self.trainer.train_with_tpu_summary_optimization_no_decorator)
output = self._validate_tpu_summary_optimization(optimized)
self.assertEqual(output, self.trainer.step_counter.numpy())
def test_train_with_tpu_summary_optimization_and_input_signature(self):
output = self._validate_tpu_summary_optimization(
self.trainer.train_with_tpu_summary_optimization_and_input_signature)
self.assertEqual(output, self.trainer.step_counter.numpy())
function = self.trainer.train_with_tpu_summary_optimization_and_input_signature
expected = (tf.TensorSpec((), dtype=tf.int32),)
input_signature = function.with_summaries.input_signature
self.assertEqual(input_signature, expected)
input_signature = function.without_summaries.input_signature
self.assertEqual(input_signature, expected)
if __name__ == "__main__":
tf.test.main()