Commit 9af441a4 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 274807747
parent ddd45b81
@@ -27,14 +27,13 @@ from absl import logging
import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, Text, Callable, Union, Iterator, Any
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from official.modeling.hyperparams import params_dict
from official.utils.misc import tpu_lib
FLAGS = flags.FLAGS
# TODO(yeqing): Move all the flags out of this file.
def define_common_hparams_flags():
"""Define the common flags across models."""
@@ -145,6 +144,13 @@ def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
logging.info('Saving model as TF checkpoint: %s', saved_path)
def _steps_to_run(current_step, total_steps, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be a positive integer.')
return min(total_steps - current_step, steps_per_loop)
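The helper above clamps the device-loop length so the final loop never overshoots `total_steps`. A quick standalone illustration of the clamping behavior (mirrors the helper added in this commit; the asserts are just a worked example):

```python
def _steps_to_run(current_step, total_steps, steps_per_loop):
  """Calculates steps to run on device (mirrors the helper above)."""
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be a positive integer.')
  return min(total_steps - current_step, steps_per_loop)

# With total_steps=1000 and steps_per_loop=100, every loop runs 100
# steps until the tail: at current_step=950 only 50 steps remain.
assert _steps_to_run(0, 1000, 100) == 100
assert _steps_to_run(950, 1000, 100) == 50
```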
def _no_metric():
return None
@@ -270,7 +276,34 @@ class DistributedExecutor(object):
input_data = input_fn(self._params)
return iter(strategy.experimental_distribute_dataset(input_data))
# TODO(yeqing): Extract the train_step out as a class for re-usability.
def _create_replicated_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
return _replicated_step
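The division by `strategy.num_replicas_in_sync` is what keeps the gradient scale independent of the replica count: synchronous optimizers sum per-replica gradients, so pre-dividing each replica's loss makes the summed gradient correspond to the global mean loss (the `ReduceOp.MEAN` reduce later in this diff is only for reporting). A minimal sketch of that sum/mean identity, assuming a `MirroredStrategy` and the TF 2.0 `experimental_run_v2` API used here (renamed `strategy.run` in later releases):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
  per_replica_loss = tf.reduce_mean(tf.square(x))
  # Pre-divide so the cross-replica SUM equals the global mean loss.
  return per_replica_loss / strategy.num_replicas_in_sync

per_replica = strategy.experimental_run_v2(
    replica_fn, args=(tf.constant(2.0),))
# Equals 4.0 no matter how many replicas participated.
loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
```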
def _create_train_step(self,
strategy,
model,
@@ -290,9 +323,11 @@ class DistributedExecutor(object):
Returns:
The training step callable.
"""
_replicated_step = self._create_replicated_step(strategy, model, loss_fn,
optimizer, metric)
@tf.function
def train_step(iterator):
def train_step(iterator, num_steps):
"""Performs a distributed training step.
Args:
@@ -301,28 +336,15 @@ class DistributedExecutor(object):
Returns:
The loss tensor.
"""
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
if not isinstance(num_steps, tf.Tensor):
raise ValueError('num_steps should be a Tensor. A Python object may cause '
                 'retracing.')
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses.
loss = strategy.reduce(
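The `isinstance(num_steps, tf.Tensor)` guard above exists because a `tf.function` embeds Python scalar arguments into the traced graph as constants, so every new Python value forces a fresh trace; a Tensor argument becomes a graph input and one trace is reused. A small standalone demonstration (not part of this commit):

```python
import tensorflow as tf

@tf.function
def run_steps(num_steps):
  print('tracing')  # Python side effect: executes only during tracing.
  total = tf.constant(0)
  for _ in tf.range(num_steps):
    total += 1
  return total

run_steps(10)               # prints 'tracing'
run_steps(20)               # prints 'tracing' again: new Python int.
run_steps(tf.constant(30))  # traces once for the int32 Tensor signature
run_steps(tf.constant(40))  # reuses that trace: nothing printed
```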
@@ -368,6 +390,7 @@ class DistributedExecutor(object):
summary_writer_fn: Callable[[Text, Text],
SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
custom_callbacks: List[tf.keras.callbacks.Callback] = None,
save_config: bool = True):
"""Runs distributed training.
@@ -384,6 +407,9 @@ class DistributedExecutor(object):
eval_metric_fn: metric_fn for evaluation in test_step.
summary_writer_fn: function to create summary writer.
init_checkpoint: function to load checkpoint.
custom_callbacks: A list of Keras Callback objects to run during
  training. Specifically, the `on_batch_begin()` and `on_batch_end()`
  methods are invoked during training.
save_config: bool. Whether to save params to model_dir.
Returns:
@@ -399,6 +425,25 @@ class DistributedExecutor(object):
train_metric_fn = train_metric_fn or _no_metric
eval_metric_fn = eval_metric_fn or _no_metric
if custom_callbacks and iterations_per_loop != 1:
logging.error(
'It is semantically wrong to run callbacks when '
'iterations_per_loop is not one (%s)', iterations_per_loop)
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch)
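Callbacks passed here only need the two batch-level hooks invoked above. A hypothetical example (the class and its behavior are illustrative, not from this commit) that times each host-loop step:

```python
import tensorflow as tf

class StepTimer(tf.keras.callbacks.Callback):
  """Hypothetical callback: logs wall time of each host-loop step."""

  def on_batch_begin(self, batch, logs=None):
    self._start = tf.timestamp()

  def on_batch_end(self, batch, logs=None):
    tf.print('step', batch, 'took', tf.timestamp() - self._start, 's')
```

As the warning above notes, batch-level callbacks are only meaningful with `iterations_per_loop == 1`, since each host "batch" otherwise spans many device steps.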
if save_config:
self._save_config(model_dir)
@@ -419,7 +464,6 @@ class DistributedExecutor(object):
optimizer = model.optimizer
# Training loop starts here.
# TODO(yeqing): Implementing checkpoints with Callbacks.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
initial_step = 0
@@ -446,19 +490,24 @@ class DistributedExecutor(object):
# Continue training loop.
train_step = self._create_train_step(
strategy, model, self.loss_fn(), optimizer, metric=train_metric)
strategy=strategy,
model=model,
loss_fn=self.loss_fn(),
optimizer=optimizer,
metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Training started')
for step in range(initial_step, total_steps):
while current_step < total_steps:
current_step = step + 1
train_loss = train_step(train_iterator)
if current_step % iterations_per_loop != 0:
# Skip metric if run less than one epoch.
continue
num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
_run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator,
tf.convert_to_tensor(num_steps, dtype=tf.int32))
_run_callbacks_on_batch_end(current_step)
current_step += num_steps
train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
train_loss)
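The reshaped host loop above advances by `num_steps` per iteration instead of one, so Python-side overhead (callbacks, logging, the loop itself) is paid once per `iterations_per_loop` device steps. A condensed standalone sketch of the control flow, with a stub standing in for the real `train_step`:

```python
import tensorflow as tf

def _steps_to_run(current_step, total_steps, steps_per_loop):
  return min(total_steps - current_step, steps_per_loop)

def train_step(iterator, num_steps):  # stub for the executor's tf.function
  del iterator, num_steps
  return tf.zeros([])

current_step, total_steps, iterations_per_loop = 0, 1000, 128
while current_step < total_steps:
  num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
  train_loss = train_step(None, tf.convert_to_tensor(num_steps, tf.int32))
  current_step += num_steps
# 1000 steps run as 8 host iterations: seven of 128 steps plus one of 104.
```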
@@ -487,7 +536,8 @@ class DistributedExecutor(object):
train_summary_writer(
metrics=train_metric_result, step=optimizer.iterations)
# Saves model checkpoints and runs validation steps at every epoch end.
# Saves model checkpoints and runs validation steps at every
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_steps:
@@ -46,29 +46,18 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
self._predict_post_process_fn = predict_post_process_fn
self._trainable_variables_filter = trainable_variables_filter
def _create_train_step(self,
def _create_replicated_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a distributed training step."""
@tf.function
def train_step(iterator):
"""Performs a distributed training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
iterator: an iterator that yields input tensors.
metric: eval metrics to be run outside the graph.
Returns:
The loss tensor.
"""
trainable_variables = model.trainable_variables
if self._trainable_variables_filter:
trainable_variables = self._trainable_variables_filter(
trainable_variables)
logging.info('Filter trainable variables from %d to %d',
len(model.trainable_variables), len(trainable_variables))
def _replicated_step(inputs):
"""Replicated training step."""
@@ -88,26 +77,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
trainable_variables = model.trainable_variables
if self._trainable_variables_filter:
trainable_variables = self._trainable_variables_filter(
trainable_variables)
logging.info('Filter trainable variables from %d to %d',
len(model.trainable_variables), len(trainable_variables))
grads = tape.gradient(loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
# return losses, labels
return loss
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
return train_step
return _replicated_step
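The `_trainable_variables_filter` hook above receives the model's full trainable-variable list and returns the subset to optimize. A hypothetical filter (the `'resnet50'` name prefix is illustrative, not from this commit) that freezes a backbone so only detection-head variables receive gradients:

```python
def freeze_backbone_filter(trainable_variables):
  """Hypothetical filter: drop backbone variables by name prefix."""
  return [v for v in trainable_variables
          if not v.name.startswith('resnet50')]
```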
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""