Commit 08273bc2 authored by Austin Myers, committed by TF Object Detection Team

Simplify dummy computation to ensure all model variables are built properly.

PiperOrigin-RevId: 364924351
parent d3c73c21
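
Background on why a dummy pass is needed at all: Keras layers and models create their variables lazily, on the first call, so a checkpoint restore or an EMA shadow copy that happens before any forward pass sees an incomplete variable list. A minimal standalone sketch of the effect (not code from this commit):

import tensorflow as tf

layer = tf.keras.layers.Dense(10)
print(len(layer.variables))  # 0 -- nothing has been built yet

# Running on a dummy input forces variable creation; only after this can a
# restore or a shadow copy see the kernel and bias.
_ = layer(tf.zeros([1, 4]))
print(len(layer.variables))  # 2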
@@ -216,6 +216,7 @@ class CheckpointV2Test(tf.test.TestCase):
     model_lib_v2.load_fine_tune_checkpoint(
         self._model, self._ckpt_path, checkpoint_type='',
         checkpoint_version=train_pb2.CheckpointVersion.V2,
+        run_model_on_dummy_input=True,
         input_dataset=self._train_input_fn(),
         unpad_groundtruth_tensors=True)
     np.testing.assert_allclose(self._model.weight.numpy(), 42)
@@ -228,6 +229,7 @@ class CheckpointV2Test(tf.test.TestCase):
     model_lib_v2.load_fine_tune_checkpoint(
         IncompatibleModel(), self._ckpt_path, checkpoint_type='',
         checkpoint_version=train_pb2.CheckpointVersion.V2,
+        run_model_on_dummy_input=True,
         input_dataset=self._train_input_fn(),
         unpad_groundtruth_tensors=True)
@@ -143,6 +143,42 @@ def _compute_losses_and_predictions_dicts(
   return losses_dict, prediction_dict


+def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
+  """Ensures that model variables are all built, by running on a dummy input.
+
+  Args:
+    model: A DetectionModel to be built.
+    input_dataset: The tf.data Dataset the model is being trained on. Needed
+      to get the shapes for the dummy loss computation.
+    unpad_groundtruth_tensors: A parameter passed to unstack_batch.
+  """
+  features, labels = iter(input_dataset).next()
+
+  @tf.function
+  def _dummy_computation_fn(features, labels):
+    model._is_training = False  # pylint: disable=protected-access
+    tf.keras.backend.set_learning_phase(False)
+    labels = model_lib.unstack_batch(
+        labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
+    return _compute_losses_and_predictions_dicts(model, features, labels)
+
+  strategy = tf.compat.v2.distribute.get_strategy()
+  if hasattr(tf.distribute.Strategy, 'run'):
+    strategy.run(
+        _dummy_computation_fn, args=(
+            features,
+            labels,
+        ))
+  else:
+    strategy.experimental_run_v2(
+        _dummy_computation_fn, args=(
+            features,
+            labels,
+        ))
+
+
 # TODO(kaftan): Explore removing learning_rate from this method & returning
 ## The full losses dict instead of just total_loss, then doing all summaries
 ## saving in a utility method called by the outer training loop.
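
The hasattr check above doubles as a TF version probe: Strategy.run replaced Strategy.experimental_run_v2 around TF 2.2, so the presence of the attribute indicates which API to call. A standalone sketch of the same dispatch pattern (the helper name here is illustrative, not part of this commit):

import tensorflow as tf

def run_under_strategy(fn, args):
  # Runs `fn` once per replica under the current distribution strategy;
  # outside a strategy scope this is the default (single-replica) strategy.
  strategy = tf.distribute.get_strategy()
  if hasattr(tf.distribute.Strategy, 'run'):  # TF >= 2.2
    return strategy.run(fn, args=args)
  return strategy.experimental_run_v2(fn, args=args)  # older TF 2.x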
@@ -311,9 +347,9 @@ def is_object_based_checkpoint(checkpoint_path):
   return '_CHECKPOINTABLE_OBJECT_GRAPH' in var_names


-def load_fine_tune_checkpoint(
-    model, checkpoint_path, checkpoint_type, checkpoint_version, input_dataset,
-    unpad_groundtruth_tensors):
+def load_fine_tune_checkpoint(model, checkpoint_path, checkpoint_type,
+                              checkpoint_version, run_model_on_dummy_input,
+                              input_dataset, unpad_groundtruth_tensors):
   """Load a fine tuning classification or detection checkpoint.

   To make sure the model variables are all built, this method first executes
@@ -335,6 +371,9 @@ def load_fine_tune_checkpoint(
     checkpoint_version: train_pb2.CheckpointVersion.V1 or V2 enum indicating
       whether to load checkpoints in V1 style or V2 style. In this binary
       we only support V2 style (object-based) checkpoints.
+    run_model_on_dummy_input: Whether to run the model on a dummy input in
+      order to ensure that all model variables have been built successfully
+      before loading the fine_tune_checkpoint.
     input_dataset: The tf.data Dataset the model is being trained on. Needed
       to get the shapes for the dummy loss computation.
     unpad_groundtruth_tensors: A parameter passed to unstack_batch.
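
A hypothetical call site showing the new keyword; the checkpoint path, checkpoint type, and dataset below are placeholders rather than values taken from this commit:

model_lib_v2.load_fine_tune_checkpoint(
    detection_model,
    '/tmp/ckpt/ckpt-0',  # placeholder checkpoint path
    checkpoint_type='detection',
    checkpoint_version=train_pb2.CheckpointVersion.V2,
    run_model_on_dummy_input=True,  # False skips the dummy pass, e.g. for CenterNet
    input_dataset=train_input,
    unpad_groundtruth_tensors=True)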
@@ -349,34 +388,8 @@ def load_fine_tune_checkpoint(
   if checkpoint_version == train_pb2.CheckpointVersion.V1:
     raise ValueError('Checkpoint version should be V2')

-  features, labels = iter(input_dataset).next()
-
-  @tf.function
-  def _dummy_computation_fn(features, labels):
-    model._is_training = False  # pylint: disable=protected-access
-    tf.keras.backend.set_learning_phase(False)
-    labels = model_lib.unstack_batch(
-        labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
-    return _compute_losses_and_predictions_dicts(
-        model,
-        features,
-        labels)
-
-  strategy = tf.compat.v2.distribute.get_strategy()
-  if hasattr(tf.distribute.Strategy, 'run'):
-    strategy.run(
-        _dummy_computation_fn, args=(
-            features,
-            labels,
-        ))
-  else:
-    strategy.experimental_run_v2(
-        _dummy_computation_fn, args=(
-            features,
-            labels,
-        ))
+  if run_model_on_dummy_input:
+    _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors)

   restore_from_objects_dict = model.restore_from_objects(
       fine_tune_checkpoint_type=checkpoint_type)
@@ -516,13 +529,6 @@ def train_loop(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)

-    # We run the detection_model on dummy inputs in order to ensure that the
-    # model and all its variables have been properly constructed. Specifically,
-    # this is currently necessary prior to (potentially) creating shadow copies
-    # of the model variables for the EMA optimizer.
-    dummy_image, dummy_shapes = detection_model.preprocess(
-        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
-    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

     def train_dataset_fn(input_context):
       """Callable to create train input."""
@@ -545,7 +551,14 @@ def train_loop(
         aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
     optimizer, (learning_rate,) = optimizer_builder.build(
         train_config.optimizer, global_step=global_step)

+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
     if train_config.optimizer.use_moving_average:
+      _ensure_model_is_built(detection_model, train_input,
+                             unpad_groundtruth_tensors)
       optimizer.shadow_copy(detection_model)

     if callable(learning_rate):
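
The ordering here matters because shadow copies are created from the variable list as it exists at call time. A sketch of the failure mode using the stock tf.train.ExponentialMovingAverage (this codebase uses an optimizer wrapper with shadow_copy instead, but the requirement is the same):

import tensorflow as tf

model = tf.keras.layers.Dense(10)
_ = model(tf.zeros([1, 4]))  # the "dummy run": variables must exist first

ema = tf.train.ExponentialMovingAverage(decay=0.999)
# apply() snapshots the variable list now; calling it on an unbuilt model
# would create no shadow variables at all.
ema.apply(model.variables)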
@@ -577,12 +590,11 @@ def train_loop(
           lambda: global_step % num_steps_per_iteration == 0):
         # Load a fine-tuning checkpoint.
         if train_config.fine_tune_checkpoint:
-          load_fine_tune_checkpoint(detection_model,
-                                    train_config.fine_tune_checkpoint,
-                                    fine_tune_checkpoint_type,
-                                    fine_tune_checkpoint_version,
-                                    train_input,
-                                    unpad_groundtruth_tensors)
+          load_fine_tune_checkpoint(
+              detection_model, train_config.fine_tune_checkpoint,
+              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
+              train_config.run_fine_tune_checkpoint_dummy_computation,
+              train_input, unpad_groundtruth_tensors)

       ckpt = tf.compat.v2.train.Checkpoint(
           step=global_step, model=detection_model, optimizer=optimizer)
@@ -1080,13 +1092,6 @@ def eval_continuously(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)

-  # We run the detection_model on dummy inputs in order to ensure that the
-  # model and all its variables have been properly constructed. Specifically,
-  # this is currently necessary prior to (potentially) creating shadow copies
-  # of the model variables for the EMA optimizer.
-  # dummy_image, dummy_shapes = detection_model.preprocess(
-  #     tf.zeros([1, 512, 512, 3], dtype=tf.float32))
-  # dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)

   eval_input = strategy.experimental_distribute_dataset(
       inputs.eval_input(
@@ -1106,7 +1111,14 @@ def eval_continuously(
     ckpt = tf.compat.v2.train.Checkpoint(
         step=global_step, model=detection_model, optimizer=optimizer)

+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
     if eval_config.use_moving_averages:
+      unpad_groundtruth_tensors = (eval_config.batch_size == 1 and not use_tpu)
+      _ensure_model_is_built(detection_model, eval_input,
+                             unpad_groundtruth_tensors)
       optimizer.shadow_copy(detection_model)

     ckpt.restore(latest_checkpoint).expect_partial()
@@ -14,7 +14,7 @@ enum CheckpointVersion {

 // Message for configuring DetectionModel training jobs (train.py).
-// Next id: 30
+// Next id: 31
 message TrainConfig {
   // Effective batch size to use for training.
   // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
@@ -66,6 +66,18 @@ message TrainConfig {
   // will raise an error. Instead, set fine_tune_checkpoint_type: 'full'.
   optional bool load_all_detection_checkpoint_vars = 19 [default = false];

+  // Whether to run dummy computation when loading a `fine_tune_checkpoint`.
+  // This option is true by default since it is often necessary to run the
+  // model on a dummy input before loading a `fine_tune_checkpoint`, in order
+  // to ensure that all the model variables have already been built
+  // successfully. Some meta architectures, like CenterNet, do not require
+  // dummy computation to successfully load all checkpoint variables, and in
+  // these cases this flag may be set to false to reduce startup time and
+  // memory consumption. Note, this flag only affects dummy computation when
+  // loading a `fine_tune_checkpoint`, e.g. it does not affect the dummy
+  // computation that is run when creating shadow copies of model variables
+  // when using EMA.
+  optional bool run_fine_tune_checkpoint_dummy_computation = 30 [default=true];

   // Number of steps to train the DetectionModel for. If 0, will train the
   // model indefinitely.
   optional uint32 num_steps = 9 [default=0];
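
An illustrative pipeline.config fragment exercising the new field; the surrounding values follow the usual TrainConfig schema and are placeholders:

train_config {
  batch_size: 8
  fine_tune_checkpoint: "/tmp/ckpt/ckpt-0"  # placeholder path
  fine_tune_checkpoint_type: "detection"
  fine_tune_checkpoint_version: V2
  # CenterNet-style models can skip the extra dummy forward pass:
  run_fine_tune_checkpoint_dummy_computation: false
}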