Commit 05631eec authored by liangjing

version 1

parent 7e0391d9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import logging
import tensorflow as tf
from typing import Callable, Dict, Optional, Text
from tf2_common.training import utils
class Controller(object):
"""Class that facilitates training and evaluation of models."""
def __init__(
self,
strategy: Optional[tf.distribute.Strategy] = None,
train_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
eval_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
warmup_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
global_step: Optional[tf.Variable] = None,
# Train related
train_steps: Optional[int] = None,
steps_per_loop: Optional[int] = None,
summary_dir: Optional[Text] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# summary related
summary_interval: Optional[int] = None,
# Evaluation related
eval_summary_dir: Optional[Text] = None,
eval_steps: Optional[int] = None,
eval_interval: Optional[int] = None,
eval_offset: Optional[int] = 0,
# Warmup related
device_warmup_steps: Optional[int] = None):
"""Constructs a `Controller` instance.
Args:
strategy: An instance of `tf.distribute.Strategy`.
train_fn: A callable defined as `def train_fn(num_steps)`, where
`num_steps` indicates the number of steps to run in each training loop.
eval_fn: A callable defined as `def eval_fn(num_steps)`, where `num_steps`
indicates the number of steps for one evaluation.
warmup_fn: A callable defined as `def warmup_fn(num_steps)`, where
`num_steps` indicates the number of steps to run in each warmup loop.
global_step: An integer `tf.Variable` indicating the global training step
number. Usually this can be obtained from `iterations` property of the
model's optimizer (e.g. `self.optimizer.iterations`), or users can
create their own global step variable as well. If the users create their
own global step variable, it is recommended to create the `tf.Variable`
inside strategy scope, and with
`aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
train_steps: The total (maximum) number of training steps to perform.
steps_per_loop: The number of steps to run in each "inner loop" of
training (passed to the `num_steps` parameter of `train_fn`).
summary_dir: The directory to write training summaries to. If None, it
will be set to `checkpoint_manager.directory`.
checkpoint_manager: An instance of `tf.train.CheckpointManager`.
summary_interval: Step interval for training summaries. Note that this
argument only applies to the summaries outside the training loop. If the
value is None, then training summaries are not enabled.
eval_summary_dir: The directory to write eval summaries. If None, no eval
summary will be written.
eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, evaluation is
skipped. Note that evaluation only happens outside the training loop,
whose iteration count is specified by the `steps_per_loop` parameter.
eval_offset: Step number of the first evaluation.
device_warmup_steps: The number of steps to run for warmup.
Raises:
ValueError: If both `train_fn` and `eval_fn` are None.
ValueError: If `train_fn` is not None and `train_steps` is None.
ValueError: If `steps_per_loop` is None when `train_fn` is provided.
ValueError: If `steps_per_loop` is not a positive integer.
"""
if train_fn is None and eval_fn is None:
raise ValueError("`train_fn` and `eval_fn` should not both be None")
# TODO(rxsang): Support training until exhaustion by passing
# `train_steps=-1`. Currently it cannot be supported with a host training
# loop because break statements are not supported with distributed datasets.
if train_fn is not None and train_steps is None:
raise ValueError("`train_steps` is required when `train_fn` is provided.")
if train_fn is not None and steps_per_loop is None:
raise ValueError("`steps_per_loop` is required when `train_fn is "
"provided.")
if steps_per_loop is not None and (
not isinstance(steps_per_loop, int) or steps_per_loop < 1):
raise ValueError("`steps_per_loop` should be a positive integer")
if summary_interval is not None and summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
self.strategy = strategy or tf.distribute.get_strategy()
self.train_fn = train_fn
self.eval_fn = eval_fn
self.warmup_fn = warmup_fn
self.global_step = global_step
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
self.device_warmup_steps = device_warmup_steps
self.summary_dir = summary_dir or checkpoint_manager.directory
self.checkpoint_manager = checkpoint_manager
self.summary_interval = summary_interval
summary_writer = tf.summary.create_file_writer(
self.summary_dir) if self.summary_interval else None
# TODO(rxsang): Consider passing SummaryManager directly into Controller
# for maximum customizability.
self.summary_manager = utils.SummaryManager(
summary_writer,
tf.summary.scalar,
global_step=self.global_step,
summary_interval=self.summary_interval)
if self.global_step:
tf.summary.experimental.set_step(self.global_step)
eval_summary_writer = tf.summary.create_file_writer(
eval_summary_dir) if eval_summary_dir else None
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
self.eval_steps = eval_steps
self.eval_interval = eval_interval
self.eval_offset = eval_offset
# Restore Model if needed.
if self.checkpoint_manager is not None:
model_restored = self._restore_model()
if not model_restored and self.checkpoint_manager.checkpoint_interval:
# If the model is not restored from a checkpoint, save an initial
# checkpoint.
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=self.global_step)
logging.info("Saved checkpoins in %s", ckpt_path)
# Create and initialize the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.eval_offset)
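# `utils.IntervalTrigger` lives in tf2_common.training.utils and is not shown
# in this diff. A minimal sketch of its assumed semantics, for orientation
# only: fire once `step` reaches `offset`, then once every `interval` steps,
# and fire unconditionally when `force_trigger` is set.
#
#   class IntervalTrigger(object):
#
#     def __init__(self, interval, offset):
#       self._interval = interval
#       self._next_trigger = offset
#
#     def __call__(self, step, force_trigger=False):
#       if not self._interval:
#         return force_trigger
#       if force_trigger or step >= self._next_trigger:
#         # Advance past `step` so the trigger fires again `interval` later.
#         while self._next_trigger <= step:
#           self._next_trigger += self._interval
#         return True
#       return False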
def _restore_model(self, checkpoint_path=None):
"""Restore or initialize the model.
Args:
checkpoint_path: An optional string indicating the checkpoint path to
restore from. If None, restores from `self.checkpoint_manager`.
Returns:
True if the latest checkpoint is found or restored. Otherwise False.
"""
with self.strategy.scope():
# Checkpoint restoring should be inside scope. b/139450638
if checkpoint_path is not None:
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return True
return self.checkpoint_manager.restore_or_initialize()
def _evaluate_once(self, current_step):
"""Runs the evaluation once."""
logging.info("Start evaluation at step: %s", current_step)
with self.eval_summary_manager.summary_writer.as_default():
eval_outputs = self.eval_fn(self.eval_steps)
if eval_outputs:
eval_outputs = tf.nest.map_structure(
lambda x: (x if isinstance(x, (float, bool)) else x.numpy()),
eval_outputs)
info = "step: {} evaluation metric: {}".format(
current_step, eval_outputs)
self._log_info(info)
self.eval_summary_manager.write_summaries(eval_outputs)
if "continue_training" in eval_outputs.keys():
return eval_outputs["continue_training"]
else:
return True
def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=force_trigger)
if ckpt_path is not None:
logging.info("Saved checkpoins in %s", ckpt_path)
def _maybe_evaluate(self, current_step, force_trigger=False):
if self.eval_trigger(current_step, force_trigger):
return self._evaluate_once(current_step)
else:
return True
def _log_info(self, message):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
def train(self, evaluate=True):
"""Runs the training, with optional evaluation.
This handles evaluation, gathering summaries, and saving checkpoints.
Args:
evaluate: A boolean indicating whether to perform evaluation during
training.
Raises:
RuntimeError: If `global_step` is not updated correctly in `train_fn`.
"""
if self.train_fn is None:
raise ValueError("`self.train_fn` is required when calling `train` "
"method.")
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `train` "
"method.")
if evaluate and self.eval_fn is None:
raise ValueError("`self.eval_fn` is required when calling `train` method "
"with `evaluate=True`")
step_timer = _StepTimer(self.global_step)
current_step = self.global_step.numpy()
logging.info("Train at step %s of %s", current_step, self.train_steps)
while current_step < self.train_steps:
# Calculates steps to run for the next train loop.
steps_per_loop = min(self.train_steps - current_step, self.steps_per_loop)
logging.info("Entering training loop with %s steps, at step %s of %s",
steps_per_loop, current_step, self.train_steps)
current_step += steps_per_loop
steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default():
train_outputs = self.train_fn(steps_per_loop)
# Updates and verifies the current step after a training loop finishes.
if current_step != self.global_step.numpy():
logging.warning("`self.train_fn` is not updating `global_step` "
"correctly, expected: %s, actual: %s",
current_step, self.global_step.numpy())
# Print information like metrics and steps_per_second after a training
# loop.
if train_outputs:
train_outputs = tf.nest.map_structure(
lambda x: x.numpy(), train_outputs)
steps_per_second = step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f} {}".format(
current_step, steps_per_second, train_outputs)
self._log_info(info)
train_outputs = train_outputs or {}
train_outputs["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_outputs)
self._maybe_save_checkpoints(current_step)
if evaluate:
continue_training = self._maybe_evaluate(current_step)
if not continue_training:
break
self.summary_manager.write_summaries(train_outputs, always_write=True)
self._maybe_save_checkpoints(current_step, force_trigger=True)
if evaluate:
self._maybe_evaluate(current_step, force_trigger=True)
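# Illustrative sketch (an assumption, not part of this commit) of a `train_fn`
# that satisfies the contract checked in `train()` above: each call must run
# `num_steps` steps and advance `global_step` by exactly that amount. Calling
# `optimizer.apply_gradients` once per step keeps `optimizer.iterations` (the
# usual `global_step`) in sync. `model`, `loss_fn`, and `iterator` are
# hypothetical user objects.
#
#   def make_train_fn(strategy, model, optimizer, loss_fn, iterator):
#
#     def step_fn(inputs):
#       features, labels = inputs
#       with tf.GradientTape() as tape:
#         loss = loss_fn(labels, model(features, training=True))
#       grads = tape.gradient(loss, model.trainable_variables)
#       optimizer.apply_gradients(zip(grads, model.trainable_variables))
#
#     @tf.function
#     def train_fn(num_steps):
#       for _ in tf.range(num_steps):
#         strategy.run(step_fn, args=(next(iterator),))
#
#     return train_fn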
def evaluate(self, continuous=False, timeout_fn=None):
"""Runs the evaluation.
Args:
continuous: If `True`, continuously monitors the checkpoint directory and
evaluates on the latest checkpoint. If `False`, runs the evaluation
once.
timeout_fn: Optional callable invoked after a timeout. If the function
returns True, it means that no new checkpoints will be generated and
the iterator will exit.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
"""
if self.eval_fn is None:
raise ValueError("`self.eval_fn` should not be None to call "
"`evaluate()` method.")
if not continuous and timeout_fn is not None:
raise ValueError("`timeout_fn` can be only passed when `continuous` is "
"True")
if continuous:
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory, timeout_fn=timeout_fn):
self._restore_model(checkpoint_path)
self._evaluate_once(self.global_step.numpy())
return
latest_checkpoint = self.checkpoint_manager.latest_checkpoint
if not latest_checkpoint:
raise ValueError("no checkpoint found in dir %s" %
self.checkpoint_manager.directory)
self._restore_model()
self._evaluate_once(self.global_step.numpy())
def warmup(self):
"""Runs device warmup.
This handles running a training loop on dummy data to move TF function
tracing and XLA compilation outside of the training loop.
"""
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `warmup` "
"method.")
step_timer = _StepTimer(self.global_step)
current_step = self.global_step.numpy()
logging.info("Warmup at step %s of %s", current_step,
self.device_warmup_steps)
while current_step < self.device_warmup_steps:
# Calculates steps to run for the next warmup loop.
steps_per_loop = self.device_warmup_steps
logging.info("Entering warmup loop with %s steps, at step %s of %s",
steps_per_loop, current_step, self.device_warmup_steps)
current_step += steps_per_loop
steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default():
self.warmup_fn(steps_per_loop)
steps_per_second = step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f}".format(
current_step, steps_per_second)
self._log_info(info)
class _StepTimer(object):
"""Utility class for measuring steps/second."""
def __init__(self, step):
self.step = step
self.start()
def start(self):
self.last_iteration = self.step.numpy()
self.last_time = time.time()
def steps_per_second(self, restart=True):
value = ((self.step.numpy() - self.last_iteration) /
(time.time() - self.last_time))
if restart:
self.start()
return value
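# Illustrative end-to-end usage sketch (an assumption, not part of this
# commit). `build_model`, `make_train_fn`, and `make_eval_fn` are hypothetical
# user helpers; everything else is standard TF2 API.
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     model, optimizer = build_model()
#   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
#   manager = tf.train.CheckpointManager(
#       checkpoint, directory="/tmp/ckpt", max_to_keep=3,
#       step_counter=optimizer.iterations, checkpoint_interval=1000)
#   controller = Controller(
#       strategy=strategy,
#       train_fn=make_train_fn(strategy, model, optimizer, ...),
#       eval_fn=make_eval_fn(strategy, model, ...),
#       global_step=optimizer.iterations,
#       train_steps=10000,
#       steps_per_loop=100,
#       checkpoint_manager=manager,
#       summary_interval=100,
#       eval_steps=50,
#       eval_interval=1000)
#   controller.train(evaluate=True)  # call controller.warmup() first when a
#                                    # warmup_fn/device_warmup_steps is given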
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modified optimizer_v2 implementation enabling XLA across variable updates."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables as tf_variables
class OptimizerV2Modified(optimizer_v2.OptimizerV2):
"""This is a subclass optimizer that performs variable updates in
Distribution Strategy replica context. OptimizerV2 base class is currently
under refactoring and will have better support of this.
Please refer to optimizer_v2.OptimizerV2 for more details regarding the APIs.
"""
def __init__(self, name, use_experimental_compile=False, **kwargs):
"""Create a new Optimizer.
Args:
name: Optional name prefix for variables and ops created by the optimizer.
use_experimental_compile: When set to True, applies `experimental_compile`
(XLA) to the `_distributed_apply` function.
"""
super(OptimizerV2Modified, self).__init__(name=name, **kwargs)
self.use_experimental_compile = use_experimental_compile
def apply_gradients(self,
grads_and_vars,
name=None,
experimental_aggregate_gradients=True):
"""Apply gradients to variables.
Only the last two lines are different from optimizer_v2.OptimizerV2.
Args:
grads_and_vars: List of (gradient, variable) pairs.
name: Optional name for the returned operation. Defaults to the name
passed to the `Optimizer` constructor.
experimental_aggregate_gradients: Whether to sum gradients from different
replicas in the presence of `tf.distribute.Strategy`. If False, it is
the user's responsibility to aggregate the gradients. Defaults to True.
Returns:
An `Operation` that applies the specified gradients. The `iterations`
will be automatically increased by 1.
Raises:
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
RuntimeError: If called in cross-replica context.
"""
# pylint: disable=protected-access
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
# pylint: enable=protected-access
var_list = [v for (_, v) in grads_and_vars]
with ops.name_scope_v2(self._name):
# Create iteration if necessary.
with ops.init_scope():
self._create_all_weights(var_list)
if not grads_and_vars:
# Distribution strategy does not support reducing an empty list of
# gradients
return control_flow_ops.no_op()
if distribute_ctx.in_cross_replica_context():
raise RuntimeError(
"`apply_gradients() cannot be called in cross-replica context. "
"Use `tf.distribute.Strategy.run` to enter replica "
"context.")
strategy = distribute_ctx.get_strategy()
if (not experimental_aggregate_gradients and strategy and isinstance(
strategy.extended,
parameter_server_strategy.ParameterServerStrategyExtended)):
raise NotImplementedError(
"`experimental_aggregate_gradients=False is not supported for "
"ParameterServerStrategy and CentralStorageStrategy")
apply_state = self._prepare(var_list)
if experimental_aggregate_gradients:
grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
grads_and_vars = self._aggregate_gradients(grads_and_vars)
grads_and_vars = self._transform_gradients(grads_and_vars)
self._distributed_apply(None, grads_and_vars, name, apply_state)
return self._iterations.assign_add(1, read_value=False)
def _distributed_apply_org(self, distribution, grads_and_vars, name, apply_state):
"""`apply_gradients` using a `DistributionStrategy`.
This is the _distributed_apply function in optimizer_v2,
returning a list of ops.
"""
def apply_grad_to_update_var(var, grad):
"""Apply gradient to variable."""
if isinstance(var, ops.Tensor):
raise NotImplementedError("Trying to update a Tensor ", var)
apply_kwargs = {}
if isinstance(grad, ops.IndexedSlices):
if var.constraint is not None:
raise RuntimeError(
"Cannot use a constraint function on a sparse variable.")
if "apply_state" in self._sparse_apply_args:
apply_kwargs["apply_state"] = apply_state
return self._resource_apply_sparse_duplicate_indices(
grad.values, var, grad.indices, **apply_kwargs)
if "apply_state" in self._dense_apply_args:
apply_kwargs["apply_state"] = apply_state
update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
if var.constraint is not None:
with ops.control_dependencies([update_op]):
return var.assign(var.constraint(var))
else:
return update_op
update_ops = []
with ops.name_scope(name or self._name, skip_on_eager=True):
for grad, var in grads_and_vars:
update_ops.append(apply_grad_to_update_var(var, grad))
return control_flow_ops.group(*update_ops)
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
if self.use_experimental_compile:
self._distributed_apply_compile(distribution, grads_and_vars, name,
apply_state)
else:
self._distributed_apply_org(distribution, grads_and_vars, name,
apply_state)
@tf.function(experimental_compile=True)
def _distributed_apply_compile(self, distribution, grads_and_vars, name,
apply_state):
"""This is a warpper, to return a tensor, making tf.func() happy."""
self._distributed_apply_org(distribution, grads_and_vars,
name, apply_state)
return tf.ones((), dtype=tf.bool)
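# Illustrative sketch (an assumption, not part of this commit): a minimal
# dense-only SGD built on OptimizerV2Modified, using the standard optimizer_v2
# hyperparameter helpers. A real optimizer would also implement
# _resource_apply_sparse.
#
#   class SimpleSGD(OptimizerV2Modified):
#
#     def __init__(self, learning_rate=0.01, name="SimpleSGD",
#                  use_experimental_compile=True, **kwargs):
#       super(SimpleSGD, self).__init__(
#           name, use_experimental_compile=use_experimental_compile, **kwargs)
#       self._set_hyper("learning_rate", learning_rate)
#
#     def _resource_apply_dense(self, grad, var, apply_state=None):
#       lr = self._get_hyper("learning_rate", var.dtype.base_dtype)
#       return var.assign_sub(lr * grad, use_locking=self._use_locking)
#
#     def get_config(self):
#       config = super(SimpleSGD, self).get_config()
#       config.update(
#           {"learning_rate": self._serialize_hyperparameter("learning_rate")})
#       return config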