Commit f2882f6e authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

First pass of a custom training loop library and applies it in model garden Resnet.

PiperOrigin-RevId: 295231206
parent e33fb967
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time
from absl import logging
import tensorflow.compat.v2 as tf
from typing import Callable, Dict, Optional, Text
from official.staging.training import utils
class Controller(object):
  """Class that facilitates training and evaluation of models."""

  def __init__(
      self,
      strategy: Optional[tf.distribute.Strategy] = None,
      train_fn: Optional[Callable[[tf.Tensor],
                                  Optional[Dict[Text, tf.Tensor]]]] = None,
      eval_fn: Optional[Callable[[tf.Tensor],
                                 Optional[Dict[Text, tf.Tensor]]]] = None,
      global_step: Optional[tf.Variable] = None,
      # Train related
      train_steps: Optional[int] = None,
      steps_per_loop: Optional[int] = None,
      summary_dir: Optional[Text] = None,
      checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
      # Summary related
      summary_interval: Optional[int] = None,
      # Evaluation related
      eval_summary_dir: Optional[Text] = None,
      eval_steps: Optional[int] = None,
      eval_interval: Optional[int] = None):
    """Constructs a `Controller` instance.

    Args:
      strategy: An instance of `tf.distribute.Strategy`.
      train_fn: A callable defined as `def train_fn(num_steps)`, which
        `num_steps` indicates the number of steps to run for each loop.
      eval_fn: A callable defined as `def eval_fn(num_steps)`, which `num_steps`
        indicates the number of steps for one evaluation.
      global_step: An integer `tf.Variable` indicating the global training step
        number. Usually this can be obtained from `iterations` property of the
        model's optimizer (e.g. `self.optimizer.iterations`), or users can
        create their own global step variable as well. If the users create their
        own global step variable, it is recommended to create the `tf.Variable`
        inside strategy scope, and with
        `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
      train_steps: The total (maximum) number of training steps to perform.
      steps_per_loop: The number of steps to run in each "inner loop" of
        training (passed to the `num_steps` parameter of `train_fn`).
      summary_dir: The directory to restore and write checkpoints and summaries.
        If None, it will be set to `checkpoint_manager.directory`.
      checkpoint_manager: An instance of `tf.train.CheckpointManager`.
      summary_interval: Step interval for training summaries. Note that this
        argument only applies to the summaries outside the training loop. If the
        value is None, then training summaries are not enabled.
      eval_summary_dir: The directory to write eval summaries. If None, it will
        be set to `summary_dir`.
      eval_steps: Number of steps to run evaluation.
      eval_interval: Step interval for evaluation. If None, will skip
        evaluation. Note that evaluation only happens outside the training loop,
        whose iteration size is specified by the `steps_per_loop` parameter.

    Raises:
      ValueError: If both `train_fn` and `eval_fn` are None.
      ValueError: If `train_fn` is not None and `train_steps` is None.
      ValueError: If `steps_per_loop` is None, or not a positive integer, when
        `train_fn` is provided.
      ValueError: If `summary_interval` is not a positive integer.
    """
    if train_fn is None and eval_fn is None:
      raise ValueError("`train_fn` and `eval_fn` should not both be None")
    # TODO(rxsang): Support training until exhaustion by passing
    # `train_steps=-1`. Currently it cannot be supported with a host training
    # loop because break statements are not supported with distributed dataset.
    if train_fn is not None:
      if train_steps is None:
        raise ValueError(
            "`train_steps` is required when `train_fn` is provided.")
      # NOTE: `steps_per_loop` is only validated when training is requested;
      # an eval-only Controller legitimately leaves it as None.
      if steps_per_loop is None:
        raise ValueError("`steps_per_loop` is required when `train_fn` is "
                         "provided.")
      if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
        raise ValueError("`steps_per_loop` should be a positive integer")
    if summary_interval is not None and summary_interval <= 0:
      raise ValueError("`summary_interval` should be larger than 0")

    self.strategy = strategy or tf.distribute.get_strategy()
    self.train_fn = train_fn
    self.eval_fn = eval_fn
    self.global_step = global_step
    self.train_steps = train_steps
    self.steps_per_loop = steps_per_loop
    self.summary_dir = summary_dir or checkpoint_manager.directory
    self.checkpoint_manager = checkpoint_manager
    self.summary_interval = summary_interval
    # Only create a real writer when summaries are enabled; `SummaryManager`
    # falls back to a no-op writer when given None.
    summary_writer = tf.summary.create_file_writer(
        self.summary_dir) if self.summary_interval else None
    # TODO(rxsang): Consider pass SummaryManager directly into Controller for
    # maximum customizability.
    self.summary_manager = utils.SummaryManager(
        summary_writer,
        tf.summary.scalar,
        global_step=self.global_step,
        summary_interval=self.summary_interval)
    if self.global_step:
      tf.summary.experimental.set_step(self.global_step)
    self.eval_summary_dir = eval_summary_dir or self.summary_dir
    eval_summary_writer = tf.summary.create_file_writer(self.eval_summary_dir)
    self.eval_summary_manager = utils.SummaryManager(
        eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
    self.eval_steps = eval_steps
    self.eval_interval = eval_interval

    # Restore the model if needed.
    if self.checkpoint_manager is not None:
      model_restored = self._restore_model()
      if not model_restored and self.checkpoint_manager.checkpoint_interval:
        # If the model is not restored from a checkpoint, save an initial
        # checkpoint.
        ckpt_path = self.checkpoint_manager.save(
            checkpoint_number=self.global_step)
        logging.info("Saved checkpoints in %s", ckpt_path)

    # Create and initialize the interval triggers. Guard `global_step` so an
    # eval-only Controller without a global step can still be constructed.
    self.eval_trigger = utils.IntervalTrigger(
        self.eval_interval,
        self.global_step.numpy() if self.global_step is not None else 0)

  def _restore_model(self, checkpoint_path=None):
    """Restore or initialize the model.

    Args:
      checkpoint_path: An optional string indicates the checkpoint path to
        restore. If None, will restore from `self.checkpoint_manager`.

    Returns:
      True if the latest checkpoint is found or restored. Otherwise False.
    """
    with self.strategy.scope():
      # Checkpoint restoring should be inside scope. b/139450638
      if checkpoint_path is not None:
        self.checkpoint_manager.checkpoint.restore(checkpoint_path)
        return True
      return self.checkpoint_manager.restore_or_initialize()

  def _evaluate_once(self, current_step):
    """Runs the evaluation once."""
    logging.info("Start evaluation at step: %s", current_step)
    with self.eval_summary_manager.summary_writer.as_default():
      eval_outputs = self.eval_fn(self.eval_steps)
      if eval_outputs:
        eval_outputs = tf.nest.map_structure(lambda x: x.numpy(), eval_outputs)
      info = "step: {} evaluation metric: {}".format(
          current_step, eval_outputs)
      self._log_info(info)
      self.eval_summary_manager.write_summaries(eval_outputs)

  def _maybe_save_checkpoints(self, current_step, force_trigger=False):
    """Saves a checkpoint, honoring the manager's interval unless forced."""
    if self.checkpoint_manager.checkpoint_interval:
      # `check_interval=True` makes `CheckpointManager.save` skip the save if
      # the interval has not elapsed, so a forced save must pass False here.
      ckpt_path = self.checkpoint_manager.save(
          checkpoint_number=current_step, check_interval=not force_trigger)
      if ckpt_path is not None:
        logging.info("Saved checkpoints in %s", ckpt_path)

  def _maybe_evaluate(self, current_step, force_trigger=False):
    """Runs evaluation if the eval interval trigger fires."""
    if self.eval_trigger(current_step, force_trigger):
      self._evaluate_once(current_step)

  def _log_info(self, message):
    """Logs `message` to the `info` log, and also prints to stdout."""
    logging.info(message)
    print(message)

  def train(self, evaluate=True):
    """Runs the training, with optional evaluation.

    This handles evaluation, gathering summaries, and saving checkpoints.

    Args:
      evaluate: A boolean indicates whether to perform evaluation during
        training.

    Raises:
      ValueError: If `train_fn`, `global_step`, or (when `evaluate=True`)
        `eval_fn` was not provided at construction time.
      RuntimeError: If `global_step` is not updated correctly in `train_fn`.
    """
    if self.train_fn is None:
      raise ValueError("`self.train_fn` is required when calling `train` "
                       "method.")
    if self.global_step is None:
      raise ValueError("`self.global_step` is required when calling `train` "
                       "method.")
    if evaluate and self.eval_fn is None:
      raise ValueError("`self.eval_fn` is required when calling `train` method "
                       "with `evaluate=True`")

    step_timer = _StepTimer(self.global_step)
    current_step = self.global_step.numpy()
    # Initialize so the final summary write below is safe even when training
    # is already complete and the loop body never runs.
    train_outputs = {}
    logging.info("Train at step %s of %s", current_step, self.train_steps)
    while current_step < self.train_steps:
      # Calculates steps to run for the next train loop.
      steps_per_loop = min(self.train_steps - current_step, self.steps_per_loop)
      logging.info("Entering training loop with %s steps, at step %s of %s",
                   steps_per_loop, current_step, self.train_steps)
      current_step += steps_per_loop
      # Pass a Tensor so a tf.function-backed `train_fn` does not retrace on
      # every new python value of `steps_per_loop`.
      steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)

      with self.summary_manager.summary_writer.as_default():
        train_outputs = self.train_fn(steps_per_loop)

      # Updates and verifies the current step after a training loop finishes.
      if current_step != self.global_step.numpy():
        raise RuntimeError("`self.train_fn` is not updating `global_step` "
                           "correctly, expected: %s, actual: %s" %
                           (current_step, self.global_step.numpy()))

      # Print information like metrics and steps_per_second after a training
      # loop.
      if train_outputs:
        train_outputs = tf.nest.map_structure(
            lambda x: x.numpy(), train_outputs)
      steps_per_second = step_timer.steps_per_second()
      info = "step: {} steps_per_second: {:.2f} {}".format(
          current_step, steps_per_second, train_outputs)
      self._log_info(info)

      train_outputs = train_outputs or {}
      train_outputs["steps_per_second"] = steps_per_second
      self.summary_manager.write_summaries(train_outputs)
      self._maybe_save_checkpoints(current_step)

      if evaluate:
        self._maybe_evaluate(current_step)

    # Flush the final summaries/checkpoint/evaluation regardless of intervals.
    self.summary_manager.write_summaries(train_outputs, always_write=True)
    self._maybe_save_checkpoints(current_step, force_trigger=True)
    if evaluate:
      self._maybe_evaluate(current_step, force_trigger=True)

  def evaluate(self, continuous=False, timeout_fn=None):
    """Runs the evaluation.

    Args:
      continuous: If `True`, will continuously monitor the checkpoint directory
        to evaluate on the latest checkpoint. If `False`, will do the evaluation
        once.
      timeout_fn: Optional callable to call after a timeout. If the function
        returns True, then it means that no new checkpoints will be generated
        and the iterator will exit.

    Raises:
      ValueError: If `eval_fn` was not provided, if `timeout_fn` is passed
        without `continuous=True`, or if no checkpoint is found in
        `self.checkpoint_manager.directory`.
    """
    if self.eval_fn is None:
      raise ValueError("`self.eval_fn` should not be None to call "
                       "`evaluate()` method.")
    if not continuous and timeout_fn is not None:
      raise ValueError("`timeout_fn` can be only passed when `continuous` is "
                       "True")

    if continuous:
      for checkpoint_path in tf.train.checkpoints_iterator(
          self.checkpoint_manager.directory, timeout_fn=timeout_fn):
        self._restore_model(checkpoint_path)
        self._evaluate_once(self.global_step.numpy())
      return

    latest_checkpoint = self.checkpoint_manager.latest_checkpoint
    if not latest_checkpoint:
      raise ValueError("no checkpoint found in dir %s" %
                       self.checkpoint_manager.directory)
    self._restore_model()
    self._evaluate_once(self.global_step.numpy())
class _StepTimer(object):
"""Utility class for measuring steps/second."""
def __init__(self, step):
self.step = step
self.start()
def start(self):
self.last_iteration = self.step.numpy()
self.last_time = time.time()
def steps_per_second(self, restart=True):
value = ((self.step.numpy() - self.last_iteration) /
(time.time() - self.last_time))
if restart:
self.start()
return value
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import six
import tensorflow.compat.v2 as tf
from typing import Dict, Optional, Text
@six.add_metaclass(abc.ABCMeta)
class AbstractTrainable(tf.Module):
  """An abstract class defining the APIs required for training."""

  @abc.abstractmethod
  def train(self,
            num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """Implements model training with multiple steps.

    In training, it is common to break the total training steps into several
    training loops, so users can do checkpointing, write summaries and run some
    python callbacks. This is necessary for getting good performance in TPU
    training, as the overhead for launching a multi worker tf.function may be
    large in Eager mode. It is usually encouraged to create a host training loop
    (e.g. using a `tf.range` wrapping `strategy.experimental_run_v2` inside a
    `tf.function`) in the TPU case. For the cases that don't require host
    training loop to achieve peak performance, users can just implement a simple
    python loop to drive each step.

    Args:
      num_steps: A guideline for how many training steps to run. Note that it is
        up to the model what constitutes a "step" (this may involve more than
        one update to model parameters, e.g. if training a GAN).

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries.
    """
    pass
@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluable(tf.Module):
  """An abstract class defining the APIs required for evaluation."""

  @abc.abstractmethod
  def evaluate(
      self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """Implements model evaluation.

    Args:
      num_steps: A guideline for how many evaluation steps to run. Note that it
        is up to the model what constitutes a "step". Generally, it may be
        desirable to support both a limited number of eval steps and iterating
        over a full dataset (however many steps are required) when `num_steps`
        is `None`.

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries.
    """
    pass
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import six
import tensorflow.compat.v2 as tf
from typing import Dict, Optional, Text
from official.staging.training import runnable
from official.staging.training import utils
@six.add_metaclass(abc.ABCMeta)
class StandardTrainable(runnable.AbstractTrainable):
  """Implements the standard functionality of AbstractTrainable APIs."""

  def __init__(self, use_tf_while_loop=True, use_tf_function=True):
    if use_tf_while_loop and not use_tf_function:
      raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
                       "is not supported")
    self.use_tf_while_loop = use_tf_while_loop
    self.use_tf_function = use_tf_function
    # All three are built lazily on the first call to `train`.
    self.train_dataset = None
    self.train_iter = None
    self.train_loop_fn = None

  @abc.abstractmethod
  def build_train_dataset(self):
    """Builds the training datasets.

    Returns:
      A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
    """
    pass

  def train(self,
            num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """See base class."""
    if self.train_dataset is None:
      # First call: build the input pipeline and a matching iterator.
      self.train_dataset = self.build_train_dataset()
      self.train_iter = tf.nest.map_structure(iter, self.train_dataset)

    if self.train_loop_fn is None:
      # First call: wrap `train_step` in the requested loop driver.
      step_fn = self.train_step
      if self.use_tf_while_loop:
        loop_fn = utils.create_tf_while_loop_fn(step_fn)
      else:
        if self.use_tf_function:
          step_fn = tf.function(step_fn)
        loop_fn = utils.create_loop_fn(step_fn)
      self.train_loop_fn = loop_fn

    self.train_loop_begin()
    self.train_loop_fn(self.train_iter, num_steps)
    return self.train_loop_end()

  def train_loop_begin(self):
    """Called once at the beginning of the training loop.

    This is a good place to reset metrics that accumulate values over multiple
    steps of training.
    """
    pass

  @abc.abstractmethod
  def train_step(self, iterator):
    """Implements one step of training.

    What a "step" consists of is up to the implementer. If using distribution
    strategies, the call to this method should take place in the "cross-replica
    context" for generality, to allow e.g. multiple iterator dequeues and calls
    to `strategy.experimental_run_v2`.

    Args:
      iterator: A tf.nest-compatible structure of tf.data Iterator or
        DistributedIterator.
    """
    pass

  def train_loop_end(self) -> Optional[Dict[Text, tf.Tensor]]:
    """Called at the end of the training loop.

    This is a good place to get metric results. The value returned from this
    function will be returned as-is from the train() method.

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries.
    """
    pass
@six.add_metaclass(abc.ABCMeta)
class StandardEvaluable(runnable.AbstractEvaluable):
  """Implements the standard functionality of AbstractEvaluable APIs."""

  def __init__(self, use_tf_function=True):
    self.eval_use_tf_function = use_tf_function
    # Both are built lazily on the first call to `evaluate`.
    self.eval_dataset = None
    self.eval_loop_fn = None

  @abc.abstractmethod
  def build_eval_dataset(self):
    """Builds the evaluation datasets.

    Returns:
      A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
    """
    pass

  def evaluate(
      self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """See base class."""
    if self.eval_dataset is None:
      # First call: build the evaluation input pipeline.
      self.eval_dataset = self.build_eval_dataset()

    if self.eval_loop_fn is None:
      # First call: wrap `eval_step` in a python loop driver, optionally
      # compiled as a tf.function.
      step_fn = self.eval_step
      if self.eval_use_tf_function:
        step_fn = tf.function(step_fn)
      self.eval_loop_fn = utils.create_loop_fn(step_fn)

    # TODO(b/147718615): When async RPC is enabled in eager runtime, we make
    # eval iterator as a class member so it doesn't get destroyed when out of
    # the function scope.
    self.eval_iter = tf.nest.map_structure(iter, self.eval_dataset)

    self.eval_begin()
    self.eval_loop_fn(self.eval_iter, num_steps)
    return self.eval_end()

  def eval_begin(self):
    """Called once at the beginning of the evaluation.

    This is a good place to reset metrics that accumulate values over the entire
    evaluation.
    """
    pass

  @abc.abstractmethod
  def eval_step(self, iterator):
    """Implements one step of evaluation.

    What a "step" consists of is up to the implementer. If using distribution
    strategies, the call to this method should take place in the "cross-replica
    context" for generality, to allow e.g. multiple iterator dequeues and calls
    to `strategy.experimental_run_v2`.

    Args:
      iterator: A tf.nest-compatible structure of tf.data Iterator or
        DistributedIterator.
    """
    pass

  def eval_end(self) -> Optional[Dict[Text, tf.Tensor]]:
    """Called at the end of the evaluation.

    This is a good place to get metric results. The value returned from this
    function will be returned as-is from the evaluate() method.

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries.
    """
    pass
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import inspect
import six
import tensorflow.compat.v2 as tf
def create_loop_fn(step_fn):
  """Creates a multiple steps function driven by the python while loop.

  Args:
    step_fn: A function which takes `iterator` as input.

  Returns:
    A callable defined as the `loop_fn` defination below.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """A loop function with multiple steps.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps==-1`, will
        iterate until exausting the iterator.
      state: An optional initial state before running the loop.
      reduce_fn: a callable defined as `def reduce_fn(state, value)`, where
        `value` is the outputs from `step_fn`.

    Returns:
      The updated state.
    """
    completed = 0
    try:
      while num_steps == -1 or completed < num_steps:
        step_outputs = step_fn(iterator)
        if reduce_fn is not None:
          state = reduce_fn(state, step_outputs)
        completed += 1
    except (StopIteration, tf.errors.OutOfRangeError):
      # The iterator ran out of data: stop early and return what we have.
      pass
    return state

  return loop_fn
def create_tf_while_loop_fn(step_fn):
  """Create a multiple steps function driven by tf.while_loop on the host.

  Args:
    step_fn: A function which takes `iterator` as input.

  Returns:
    A callable defined as the `loop_fn` defination below.
  """

  @tf.function
  def loop_fn(iterator, num_steps):
    """A loop function with multiple steps.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Must be a tf.Tensor.

    Raises:
      ValueError: If `num_steps` is a python object rather than a `tf.Tensor`,
        since a python value would be baked into the traced graph and force a
        retrace of the tf.function for every new value.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError("`num_steps` should be an `tf.Tensor`. Python object "
                       "may cause retracing.")

    # Looping over a `tf.range` keeps all `num_steps` steps inside a single
    # graph invocation on the host (per the tf.while_loop intent above).
    for _ in tf.range(num_steps):
      step_fn(iterator)

  return loop_fn
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A helper function to create distributed dataset.

  Args:
    strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: A instance of `tf.data.Dataset` or a function which takes an
      `tf.distribute.InputContext` as input and returns a `tf.data.Dataset`. If
      it is a function, it could optionally have an argument named
      `input_context` which is `tf.distribute.InputContext` argument type.
    *args: The list of arguments to be passed to dataset_or_fn.
    **kwargs: Any keyword arguments to be passed.

  Returns:
    A distributed Dataset.

  Raises:
    ValueError: If `dataset_or_fn` is neither callable nor a
      `tf.data.Dataset`.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)
  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`")

  def dataset_fn(ctx):
    """Wrapped dataset function for creating distributed dataset."""
    # Forward `ctx` as `input_context` only to functions that declare such a
    # parameter; other functions are called with the user's args unchanged.
    getspec = inspect.getfullargspec if six.PY3 else inspect.getargspec
    if "input_context" in getspec(dataset_or_fn).args:
      kwargs["input_context"] = ctx
    return dataset_or_fn(*args, **kwargs)

  return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager(object):
  """A class manages writing summaries."""

  def __init__(self,
               summary_writer,
               summary_fn,
               global_step=None,
               summary_interval=None):
    """Construct a summary manager object.

    Args:
      summary_writer: A `tf.summary.SummaryWriter` instance for writing
        summaries.
      summary_fn: A callable defined as `def summary_fn(name, tensor,
        step=None)`, which describes the summary operation.
      global_step: A `tf.Variable` instance for checking the current global step
        value, in case users want to save summaries every N steps.
      summary_interval: An integer, indicates the minimum step interval between
        two summaries.

    Raises:
      ValueError: If `summary_interval` is set but no global step is available
        (neither `global_step` nor `tf.summary.experimental.get_step()`
        provided one).
    """
    if summary_writer is not None:
      self._summary_writer = summary_writer
      self._enabled = True
    else:
      # No writer supplied: use a no-op writer and mark summaries disabled so
      # `write_summaries` becomes a cheap no-op.
      self._summary_writer = tf.summary.create_noop_writer()
      self._enabled = False
    self._summary_fn = summary_fn
    if global_step is None:
      # Fall back to the default step registered via
      # `tf.summary.experimental.set_step`, if any.
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step
    if summary_interval is not None:
      if self._global_step is None:
        raise ValueError("`summary_interval` is not None, but no `global_step` "
                         "can be obtained ")
      # Remember the step at which summaries were last written, so the
      # interval gating in `write_summaries` has a starting point.
      self._last_summary_step = self._global_step.numpy()
    self._summary_interval = summary_interval

  @property
  def summary_interval(self):
    # Minimum number of steps between two written summaries; may be None.
    return self._summary_interval

  @property
  def summary_writer(self):
    """Returns the underlying summary writer."""
    return self._summary_writer

  def write_summaries(self, items, always_write=True):
    """Write a bulk of summaries.

    Args:
      items: a dictionary of `Tensors` for writing summaries.
      always_write: An optional boolean. If `True`, the manager will always
        write summaries unless the summaries have been written for the same
        step. Otherwise the manager will only write the summaries if the
        interval between summaries are larger than `summary_interval`.

    Returns:
      A boolean indicates whether the summaries are written or not.
    """
    # TODO(rxsang): Support writing summaries with nested structure, so users
    # can split the summaries into different directories for nicer visualization
    # in Tensorboard, like train and eval metrics.
    if not self._enabled:
      return False
    if self._summary_interval is not None:
      current_step = self._global_step.numpy()
      if current_step == self._last_summary_step:
        # Never write twice for the same step.
        return False
      if not always_write and current_step < (self._last_summary_step +
                                              self._summary_interval):
        return False
      self._last_summary_step = current_step
    with self._summary_writer.as_default():
      for name, tensor in items.items():
        self._summary_fn(name, tensor, step=self._global_step)
    return True
@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
  """An abstract class representing a "trigger" for some event."""

  @abc.abstractmethod
  def __call__(self, value: float, force_trigger: bool = False):
    """Maybe trigger the event based on the given value.

    Args:
      value: the value for triggering.
      force_trigger: Whether the trigger is forced triggered.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """

  @abc.abstractmethod
  def reset(self):
    """Reset states in the trigger."""
class IntervalTrigger(Trigger):
  """Triggers on every fixed interval."""

  def __init__(self, interval, start=0):
    """Constructs the IntervalTrigger.

    Args:
      interval: The triggering interval. May be None or non-positive, in which
        case the trigger only fires when `force_trigger=True`.
      start: An initial value for the trigger.
    """
    self._interval = interval
    self._last_trigger_value = start

  def __call__(self, value, force_trigger=False):
    """Maybe trigger the event based on the given value.

    Args:
      value: the value for triggering.
      force_trigger: If True, the trigger will be forced triggered unless the
        last trigger value is equal to `value`.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """
    if force_trigger and value != self._last_trigger_value:
      self._last_trigger_value = value
      return True

    # A None interval (e.g. Controller's `eval_interval` left unset) used to
    # crash the `> 0` comparison on Python 3; treat it like a non-positive
    # interval: never trigger unless forced.
    if self._interval and self._interval > 0:
      if value >= self._last_trigger_value + self._interval:
        self._last_trigger_value = value
        return True
    return False

  def reset(self):
    """See base class."""
    self._last_trigger_value = 0
class EpochHelper(object):
  """A Helper class to handle epochs in Customized Training Loop."""

  def __init__(self, epoch_steps, global_step):
    """Constructs the EpochHelper.

    Args:
      epoch_steps: An integer indicates how many steps in an epoch.
      global_step: A `tf.Variable` instance indicates the current global step.
    """
    self._epoch_steps = epoch_steps
    self._global_step = global_step
    self._current_epoch = None
    self._in_epoch = False

  def epoch_begin(self):
    """Returns whether a new epoch should begin."""
    if self._in_epoch:
      return False
    # Use floor division: with `from __future__ import division` in effect,
    # `/` would produce a fractional "epoch" that grows on every step, making
    # `epoch_end` report an epoch boundary after every single step.
    self._current_epoch = self._global_step.numpy() // self._epoch_steps
    self._in_epoch = True
    return True

  def epoch_end(self):
    """Returns whether the current epoch should end.

    Raises:
      ValueError: If called while not inside an epoch (i.e. without a
        preceding successful `epoch_begin`).
    """
    if not self._in_epoch:
      raise ValueError("`epoch_end` can only be called inside an epoch")
    current_step = self._global_step.numpy()
    epoch = current_step // self._epoch_steps
    if epoch > self._current_epoch:
      self._in_epoch = False
      return True
    return False

  @property
  def current_epoch(self):
    # Epoch index recorded at the most recent `epoch_begin`.
    return self._current_epoch
...@@ -346,10 +346,17 @@ def define_keras_flags( ...@@ -346,10 +346,17 @@ def define_keras_flags(
flags.DEFINE_string( flags.DEFINE_string(
name='tpu', default='', help='TPU address to connect to.') name='tpu', default='', help='TPU address to connect to.')
flags.DEFINE_integer( flags.DEFINE_integer(
name='steps_per_loop', default=1, name='steps_per_loop',
help='Number of steps per graph-mode loop. Only training step happens ' default=500,
help='Number of steps per training loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at ' 'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.') 'steps per epoch.')
flags.DEFINE_boolean(
name='use_tf_while_loop',
default=True,
help='Whether to build a tf.while_loop inside the training loop on the '
'host. Setting it to True is critical to have peak performance on '
'TPU.')
flags.DEFINE_boolean( flags.DEFINE_boolean(
name='use_tf_keras_layers', default=False, name='use_tf_keras_layers', default=False,
help='Whether to use tf.keras.layers instead of tf.python.keras.layers.' help='Whether to use tf.keras.layers instead of tf.python.keras.layers.'
......
...@@ -18,21 +18,20 @@ from __future__ import absolute_import ...@@ -18,21 +18,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.vision.image_classification import imagenet_preprocessing from official.staging.training import controller
from official.vision.image_classification import common
from official.vision.image_classification import resnet_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_runnable
flags.DEFINE_boolean(name='use_tf_function', default=True, flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a ' help='Wrap the train and test step inside a '
...@@ -42,110 +41,42 @@ flags.DEFINE_boolean(name='single_l2_loss_op', default=False, ...@@ -42,110 +41,42 @@ flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
'instead of using Keras per-layer L2 loss.') 'instead of using Keras per-layer L2 loss.')
def build_stats(train_result, eval_result, time_callback, avg_exp_per_second): def build_stats(runnable, time_callback):
"""Normalizes and returns dictionary of stats. """Normalizes and returns dictionary of stats.
Args: Args:
train_result: The final loss at training time. runnable: The module containing all the training and evaluation metrics.
eval_result: Output of the eval step. Assumes first value is eval_loss and
second value is accuracy_top_1.
time_callback: Time tracking callback instance. time_callback: Time tracking callback instance.
avg_exp_per_second: Average training examples per second.
Returns: Returns:
Dictionary of normalized results. Dictionary of normalized results.
""" """
stats = {} stats = {}
if eval_result: if not runnable.flags_obj.skip_eval:
stats['eval_loss'] = eval_result[0] stats['eval_loss'] = runnable.test_loss.result().numpy()
stats['eval_acc'] = eval_result[1] stats['eval_acc'] = runnable.test_accuracy.result().numpy()
stats['train_loss'] = train_result[0] stats['train_loss'] = runnable.train_loss.result().numpy()
stats['train_acc'] = train_result[1] stats['train_acc'] = runnable.train_accuracy.result().numpy()
if time_callback: if time_callback:
timestamp_log = time_callback.timestamp_log timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
avg_exp_per_second = tf.reduce_mean(
runnable.examples_per_second_history).numpy(),
stats['avg_exp_per_second'] = avg_exp_per_second stats['avg_exp_per_second'] = avg_exp_per_second
return stats return stats
def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
dtype = flags_core.get_tf_dtype(flags_obj)
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
batch_size = flags_obj.batch_size
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
'Batch size must be divisible by number of replicas : {}'.format(
strategy.num_replicas_in_sync))
# As auto rebatching is not supported in
# `experimental_distribute_datasets_from_function()` API, which is
# required when cloning dataset to multiple workers in eager mode,
# we use per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
if flags_obj.use_synthetic_data:
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype,
drop_remainder=True)
else:
input_fn = imagenet_preprocessing.input_fn
def _train_dataset_fn(ctx=None):
train_ds = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype,
input_context=ctx,
drop_remainder=True)
return train_ds
if strategy:
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
train_ds = strategy.experimental_distribute_datasets_from_function(_train_dataset_fn)
else:
train_ds = strategy.experimental_distribute_dataset(_train_dataset_fn())
else:
train_ds = _train_dataset_fn()
test_ds = None
if not flags_obj.skip_eval:
def _test_data_fn(ctx=None):
test_ds = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=dtype,
input_context=ctx)
return test_ds
if strategy:
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
test_ds = strategy.experimental_distribute_datasets_from_function(
_test_data_fn)
else:
test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
else:
test_ds = _test_data_fn()
return train_ds, test_ds
def get_num_train_iterations(flags_obj): def get_num_train_iterations(flags_obj):
"""Returns the number of training steps, train and test epochs.""" """Returns the number of training steps, train and test epochs."""
train_steps = ( train_steps = (
...@@ -214,223 +145,53 @@ def run(flags_obj): ...@@ -214,223 +145,53 @@ def run(flags_obj):
num_packs=flags_obj.num_packs, num_packs=flags_obj.num_packs,
tpu_address=flags_obj.tpu) tpu_address=flags_obj.tpu)
train_ds, test_ds = get_input_dataset(flags_obj, strategy)
per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations( per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
flags_obj) flags_obj)
steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps) steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
logging.info("Training %d epochs, each epoch has %d steps, "
"total steps: %d; Eval %d steps", logging.info(
train_epochs, per_epoch_steps, train_epochs * per_epoch_steps, 'Training %d epochs, each epoch has %d steps, '
eval_steps) 'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
train_epochs * per_epoch_steps, eval_steps)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size, time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
flags_obj.log_steps) flags_obj.log_steps)
with distribution_utils.get_strategy_scope(strategy): with distribution_utils.get_strategy_scope(strategy):
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers) runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
model = resnet_model.resnet50( per_epoch_steps)
num_classes=imagenet_preprocessing.NUM_CLASSES,
batch_size=flags_obj.batch_size, eval_interval = (
use_l2_regularizer=not flags_obj.single_l2_loss_op) flags_obj.epochs_between_evals *
per_epoch_steps if not flags_obj.skip_eval else None)
lr_schedule = common.PiecewiseConstantDecayWithWarmup( checkpoint_interval = (
batch_size=flags_obj.batch_size, per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'], summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None
warmup_epochs=common.LR_SCHEDULE[0][1],
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]), checkpoint_manager = tf.train.CheckpointManager(
multipliers=list(p[0] for p in common.LR_SCHEDULE), runnable.checkpoint,
compute_lr_on_cpu=True) directory=flags_obj.model_dir,
optimizer = common.get_optimizer(lr_schedule) max_to_keep=10,
step_counter=runnable.global_step,
if dtype == tf.float16: checkpoint_interval=checkpoint_interval)
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( resnet_controller = controller.Controller(
optimizer, loss_scale) strategy,
elif flags_obj.fp16_implementation == 'graph_rewrite': runnable.train,
# `dtype` is still float32 in this case. We built the graph in float32 and runnable.evaluate,
# let the graph rewrite change parts of it float16. global_step=runnable.global_step,
if not flags_obj.use_tf_function: steps_per_loop=steps_per_loop,
raise ValueError('--fp16_implementation=graph_rewrite requires ' train_steps=per_epoch_steps * train_epochs,
'--use_tf_function to be true') checkpoint_manager=checkpoint_manager,
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128) summary_interval=summary_interval,
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( eval_steps=eval_steps,
optimizer, loss_scale) eval_interval=eval_interval)
current_step = 0 time_callback.on_train_begin()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) resnet_controller.train(evaluate=True)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir) time_callback.on_train_end()
if latest_checkpoint:
checkpoint.restore(latest_checkpoint) stats = build_stats(runnable, time_callback)
logging.info("Load checkpoint %s", latest_checkpoint) return stats
current_step = optimizer.iterations.numpy()
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
trainable_variables = model.trainable_variables
def step_fn(inputs):
"""Per-Replica StepFn."""
images, labels = inputs
with tf.GradientTape() as tape:
logits = model(images, training=True)
prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits)
loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
tf.nn.l2_loss(v)
for v in trainable_variables
if 'bn' not in v.name
])
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss
if flags_obj.dtype == "fp16":
loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(loss, trainable_variables)
# Unscale the grads
if flags_obj.dtype == "fp16":
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, trainable_variables))
train_loss.update_state(loss)
training_accuracy.update_state(labels, logits)
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop."""
for _ in tf.range(steps):
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
def train_single_step(iterator):
if strategy:
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
else:
return step_fn(next(iterator))
def test_step(iterator):
"""Evaluation StepFn."""
def step_fn(inputs):
images, labels = inputs
logits = model(images, training=False)
loss = tf.keras.losses.sparse_categorical_crossentropy(labels,
logits)
loss = tf.reduce_sum(loss) * (1.0/ flags_obj.batch_size)
test_loss.update_state(loss)
test_accuracy.update_state(labels, logits)
if strategy:
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
else:
step_fn(next(iterator))
if flags_obj.use_tf_function:
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
if flags_obj.enable_tensorboard:
summary_writer = tf.summary.create_file_writer(flags_obj.model_dir)
else:
summary_writer = None
examples_per_second_history = []
train_iter = iter(train_ds)
time_callback.on_train_begin()
for epoch in range(current_step // per_epoch_steps, train_epochs):
train_loss.reset_states()
training_accuracy.reset_states()
steps_in_current_epoch = 0
time_callback.on_epoch_begin(epoch + 1)
while steps_in_current_epoch < per_epoch_steps:
time_callback.on_batch_begin(
steps_in_current_epoch+epoch*per_epoch_steps)
steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
steps_per_loop)
if steps == 1:
train_single_step(train_iter)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
time_callback.on_batch_end(
steps_in_current_epoch+epoch*per_epoch_steps)
steps_in_current_epoch += steps
logging.info('Training loss: %s, accuracy: %s at epoch %d',
train_loss.result().numpy(),
training_accuracy.result().numpy(),
epoch + 1)
time_callback.on_epoch_end(epoch + 1)
epoch_time = time_callback.epoch_runtime_log[-1]
steps_per_second = per_epoch_steps / epoch_time
examples_per_second = steps_per_second * flags_obj.batch_size
examples_per_second_history.append(examples_per_second)
if (not flags_obj.skip_eval and
(epoch + 1) % flags_obj.epochs_between_evals == 0):
test_loss.reset_states()
test_accuracy.reset_states()
test_iter = iter(test_ds)
for _ in range(eval_steps):
test_step(test_iter)
logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
test_loss.result().numpy(),
test_accuracy.result().numpy(),
epoch + 1)
if flags_obj.enable_checkpoint_and_export:
checkpoint_name = checkpoint.save(
os.path.join(flags_obj.model_dir,
'model.ckpt-{}'.format(epoch + 1)))
logging.info('Saved checkpoint to %s', checkpoint_name)
if summary_writer:
current_steps = steps_in_current_epoch + (epoch * per_epoch_steps)
with summary_writer.as_default():
tf.summary.scalar('train_loss', train_loss.result(), current_steps)
tf.summary.scalar(
'train_accuracy', training_accuracy.result(), current_steps)
tf.summary.scalar('eval_loss', test_loss.result(), current_steps)
tf.summary.scalar(
'eval_accuracy', test_accuracy.result(), current_steps)
tf.summary.scalar('global_step/sec', steps_per_second, current_steps)
tf.summary.scalar('examples/sec', examples_per_second, current_steps)
time_callback.on_train_end()
if summary_writer:
summary_writer.close()
eval_result = None
train_result = None
if not flags_obj.skip_eval:
eval_result = [test_loss.result().numpy(),
test_accuracy.result().numpy()]
train_result = [train_loss.result().numpy(),
training_accuracy.result().numpy()]
stats = build_stats(
train_result,
eval_result,
time_callback,
tf.reduce_mean(examples_per_second_history).numpy(),
)
return stats
def main(_): def main(_):
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile import tempfile
import os
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
...@@ -64,9 +65,10 @@ class CtlImagenetTest(tf.test.TestCase): ...@@ -64,9 +65,10 @@ class CtlImagenetTest(tf.test.TestCase):
def test_end_to_end_no_dist_strat(self): def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy.""" """Test Keras model with 1 GPU, no distribution strategy."""
model_dir = os.path.join(self.get_temp_dir(), 'ctl_imagenet_no_dist_strat')
extra_flags = [ extra_flags = [
'-distribution_strategy', 'off', '-distribution_strategy', 'off',
'-model_dir', 'ctl_imagenet_no_dist_strat', '-model_dir', model_dir,
'-data_format', 'channels_last', '-data_format', 'channels_last',
] ]
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
...@@ -83,10 +85,11 @@ class CtlImagenetTest(tf.test.TestCase): ...@@ -83,10 +85,11 @@ class CtlImagenetTest(tf.test.TestCase):
if context.num_gpus() < 2: if context.num_gpus() < 2:
num_gpus = '0' num_gpus = '0'
model_dir = os.path.join(self.get_temp_dir(), 'ctl_imagenet_2_gpu')
extra_flags = [ extra_flags = [
'-num_gpus', num_gpus, '-num_gpus', num_gpus,
'-distribution_strategy', 'mirrored', '-distribution_strategy', 'mirrored',
'-model_dir', 'ctl_imagenet_2_gpu', '-model_dir', model_dir,
'-data_format', 'channels_last', '-data_format', 'channels_last',
] ]
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
class ResnetRunnable(standard_runnable.StandardTrainable,
                     standard_runnable.StandardEvaluable):
  """Implements the training and evaluation APIs for Resnet model."""

  def __init__(self, flags_obj, time_callback, epoch_steps):
    """Builds the model, optimizer, metrics and input-pipeline configuration.

    Args:
      flags_obj: Object holding parsed command-line flags (batch size, dtype,
        synthetic-data toggle, fp16 settings, tf.function/while-loop toggles,
        etc.).
      time_callback: Callback with `on_batch_begin/end` and
        `on_epoch_begin/end` hooks plus an `epoch_runtime_log` list; used for
        timing/throughput bookkeeping (looks like `keras_utils.TimeHistory` —
        confirm at call site).
      epoch_steps: Number of training steps that constitute one epoch.
    """
    standard_runnable.StandardTrainable.__init__(self,
                                                 flags_obj.use_tf_while_loop,
                                                 flags_obj.use_tf_function)
    standard_runnable.StandardEvaluable.__init__(self,
                                                 flags_obj.use_tf_function)
    # Must be constructed inside the strategy scope so variables are placed
    # correctly; the caller is expected to enter the scope first.
    self.strategy = tf.distribute.get_strategy()
    self.flags_obj = flags_obj
    self.dtype = flags_core.get_tf_dtype(flags_obj)
    self.time_callback = time_callback

    # Input pipeline related
    batch_size = flags_obj.batch_size
    if batch_size % self.strategy.num_replicas_in_sync != 0:
      raise ValueError(
          'Batch size must be divisible by number of replicas : {}'.format(
              self.strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning dataset to multiple workers in eager mode,
    # we use per-replica batch size.
    self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

    # Synthetic data skips file I/O entirely; otherwise read real ImageNet
    # records via the shared preprocessing input_fn.
    if self.flags_obj.use_synthetic_data:
      self.input_fn = common.get_synth_input_fn(
          height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          num_channels=imagenet_preprocessing.NUM_CHANNELS,
          num_classes=imagenet_preprocessing.NUM_CLASSES,
          dtype=self.dtype,
          drop_remainder=True)
    else:
      self.input_fn = imagenet_preprocessing.input_fn

    # Optionally swap the implementation backing the Keras layers used by the
    # model, then build ResNet-50. Per-layer L2 regularizers are disabled when
    # a single fused L2 loss op is requested (see train_step).
    resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
    self.model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    self.optimizer = common.get_optimizer(lr_schedule)
    # Make sure iterations variable is created inside scope.
    # The optimizer's step counter doubles as the training global step.
    self.global_step = self.optimizer.iterations

    if self.dtype == tf.float16:
      # Keras-level loss scaling: train_step must then scale the loss and
      # unscale the gradients explicitly (it checks flags_obj.dtype there).
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      self.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              self.optimizer, loss_scale))
    elif flags_obj.fp16_implementation == 'graph_rewrite':
      # `dtype` is still float32 in this case. We built the graph in float32
      # and let the graph rewrite change parts of it float16.
      if not flags_obj.use_tf_function:
        raise ValueError('--fp16_implementation=graph_rewrite requires '
                         '--use_tf_function to be true')
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      self.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              self.optimizer, loss_scale))

    # Metrics are accumulated across steps and reset at the start of each
    # training loop (train_loop_begin) / evaluation (eval_begin).
    self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'train_accuracy', dtype=tf.float32)
    self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    # Checkpoint covers model weights and optimizer slots (incl. iterations).
    self.checkpoint = tf.train.Checkpoint(
        model=self.model, optimizer=self.optimizer)

    # Handling epochs.
    self.epoch_steps = epoch_steps
    self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
    # Per-epoch throughput samples; averaged later for reporting
    # (consumed by build_stats in the driver script).
    self.examples_per_second_history = []

  def build_train_dataset(self):
    """See base class.

    Returns a distributed training dataset built from `self.input_fn` with
    per-replica batch size and dropped remainders.
    """
    return utils.make_distributed_dataset(
        self.strategy,
        self.input_fn,
        is_training=True,
        data_dir=self.flags_obj.data_dir,
        batch_size=self.batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=self.flags_obj
        .datasets_num_private_threads,
        dtype=self.dtype,
        drop_remainder=True)

  def build_eval_dataset(self):
    """See base class.

    Returns a distributed evaluation dataset; remainders are kept (no
    `drop_remainder`), unlike the training dataset.
    """
    return utils.make_distributed_dataset(
        self.strategy,
        self.input_fn,
        is_training=False,
        data_dir=self.flags_obj.data_dir,
        batch_size=self.batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=self.dtype)

  def train_loop_begin(self):
    """See base class."""
    # Reset all metrics
    self.train_loss.reset_states()
    self.train_accuracy.reset_states()

    self.time_callback.on_batch_begin(self.global_step)
    self._epoch_begin()

  def train_step(self, iterator):
    """See base class.

    Runs one distributed training step: forward pass, cross-entropy plus L2
    regularization loss, backward pass and optimizer update, and metric
    accumulation on every replica.

    Args:
      iterator: Iterator over the (distributed) training dataset yielding
        `(images, labels)` per replica.
    """

    def step_fn(inputs):
      """Function to run on the device."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = self.model(images, training=True)

        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        # Normalize by the *global* batch size, not the per-replica one, so
        # summing gradients across replicas yields the correct total.
        loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                 self.flags_obj.batch_size)
        num_replicas = self.strategy.num_replicas_in_sync

        if self.flags_obj.single_l2_loss_op:
          # Single fused L2 op over all non-batch-norm variables; the model
          # was built without per-layer regularizers in this mode.
          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
              tf.nn.l2_loss(v)
              for v in self.model.trainable_variables
              if 'bn' not in v.name
          ])

          loss += (l2_loss / num_replicas)
        else:
          # Per-layer regularization losses collected by the Keras model.
          loss += (tf.reduce_sum(self.model.losses) / num_replicas)

        # Scale the loss
        # (only when the Keras LossScaleOptimizer wraps the optimizer).
        if self.flags_obj.dtype == 'fp16':
          loss = self.optimizer.get_scaled_loss(loss)

      grads = tape.gradient(loss, self.model.trainable_variables)

      # Unscale the grads
      if self.flags_obj.dtype == 'fp16':
        grads = self.optimizer.get_unscaled_gradients(grads)

      self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
      # NOTE(review): train_loss accumulates the (possibly scaled-then-
      # unscaled) per-replica sum-normalized loss including regularization.
      self.train_loss.update_state(loss)
      self.train_accuracy.update_state(labels, logits)

    self.strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  def train_loop_end(self):
    """See base class.

    Returns:
      A dict mapping metric names to their current (accumulated) results.
    """
    self.time_callback.on_batch_end(self.global_step)
    self._epoch_end()
    return {
        'train_loss': self.train_loss.result(),
        'train_accuracy': self.train_accuracy.result(),
    }

  def eval_begin(self):
    """See base class."""
    self.test_loss.reset_states()
    self.test_accuracy.reset_states()

  def eval_step(self, iterator):
    """See base class.

    Runs one distributed evaluation step (forward pass only) and accumulates
    the test loss/accuracy metrics.

    Args:
      iterator: Iterator over the (distributed) evaluation dataset yielding
        `(images, labels)` per replica.
    """

    def step_fn(inputs):
      """Function to run on the device."""
      images, labels = inputs
      logits = self.model(images, training=False)
      loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
      loss = tf.reduce_sum(loss) * (1.0 / self.flags_obj.batch_size)
      self.test_loss.update_state(loss)
      self.test_accuracy.update_state(labels, logits)

    self.strategy.experimental_run_v2(step_fn, args=(next(iterator),))

  def eval_end(self):
    """See base class.

    Returns:
      A dict mapping evaluation metric names to their accumulated results.
    """
    return {
        'test_loss': self.test_loss.result(),
        'test_accuracy': self.test_accuracy.result()
    }

  def _epoch_begin(self):
    # Fires the epoch-begin timing hook only when the global step has just
    # crossed into a new epoch (EpochHelper tracks the boundary).
    if self.epoch_helper.epoch_begin():
      self.time_callback.on_epoch_begin(self.epoch_helper.current_epoch)

  def _epoch_end(self):
    # On an epoch boundary: record timing, compute throughput, and emit
    # summaries.
    if self.epoch_helper.epoch_end():
      self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
      epoch_time = self.time_callback.epoch_runtime_log[-1]
      steps_per_second = self.epoch_steps / epoch_time
      examples_per_second = steps_per_second * self.flags_obj.batch_size
      self.examples_per_second_history.append(examples_per_second)
      # NOTE(review): these scalar summaries pass no `step` argument, so they
      # rely on a default summary step/writer being active (presumably set by
      # the controller running this class) — confirm, else they are no-ops or
      # raise.
      tf.summary.scalar('global_step/sec', steps_per_second)
      tf.summary.scalar('examples/sec', examples_per_second)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment