Commit 08273bc2 authored by Austin Myers, committed by TF Object Detection Team

Simplify dummy computation to ensure all model variables are built properly.

PiperOrigin-RevId: 364924351
parent d3c73c21
@@ -216,6 +216,7 @@ class CheckpointV2Test(tf.test.TestCase):
     model_lib_v2.load_fine_tune_checkpoint(
         self._model, self._ckpt_path, checkpoint_type='',
         checkpoint_version=train_pb2.CheckpointVersion.V2,
+        run_model_on_dummy_input=True,
         input_dataset=self._train_input_fn(),
         unpad_groundtruth_tensors=True)
     np.testing.assert_allclose(self._model.weight.numpy(), 42)
@@ -228,6 +229,7 @@ class CheckpointV2Test(tf.test.TestCase):
       model_lib_v2.load_fine_tune_checkpoint(
           IncompatibleModel(), self._ckpt_path, checkpoint_type='',
           checkpoint_version=train_pb2.CheckpointVersion.V2,
+          run_model_on_dummy_input=True,
           input_dataset=self._train_input_fn(),
           unpad_groundtruth_tensors=True)
...
@@ -143,6 +143,42 @@ def _compute_losses_and_predictions_dicts(
   return losses_dict, prediction_dict
 
 
+def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
+  """Ensures that model variables are all built, by running on a dummy input.
+
+  Args:
+    model: A DetectionModel to be built.
+    input_dataset: The tf.data Dataset the model is being trained on. Needed to
+      get the shapes for the dummy loss computation.
+    unpad_groundtruth_tensors: A parameter passed to unstack_batch.
+  """
+  features, labels = iter(input_dataset).next()
+
+  @tf.function
+  def _dummy_computation_fn(features, labels):
+    model._is_training = False  # pylint: disable=protected-access
+    tf.keras.backend.set_learning_phase(False)
+    labels = model_lib.unstack_batch(
+        labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
+    return _compute_losses_and_predictions_dicts(model, features, labels)
+
+  strategy = tf.compat.v2.distribute.get_strategy()
+  if hasattr(tf.distribute.Strategy, 'run'):
+    strategy.run(
+        _dummy_computation_fn, args=(
+            features,
+            labels,
+        ))
+  else:
+    strategy.experimental_run_v2(
+        _dummy_computation_fn, args=(
+            features,
+            labels,
+        ))
+
+
 # TODO(kaftan): Explore removing learning_rate from this method & returning
 ## The full losses dict instead of just total_loss, then doing all summaries
 ## saving in a utility method called by the outer training loop.
@@ -311,9 +347,9 @@ def is_object_based_checkpoint(checkpoint_path):
   return '_CHECKPOINTABLE_OBJECT_GRAPH' in var_names
 
 
-def load_fine_tune_checkpoint(
-    model, checkpoint_path, checkpoint_type, checkpoint_version, input_dataset,
-    unpad_groundtruth_tensors):
+def load_fine_tune_checkpoint(model, checkpoint_path, checkpoint_type,
+                              checkpoint_version, run_model_on_dummy_input,
+                              input_dataset, unpad_groundtruth_tensors):
   """Load a fine tuning classification or detection checkpoint.
 
   To make sure the model variables are all built, this method first executes
@@ -335,6 +371,9 @@ def load_fine_tune_checkpoint(
     checkpoint_version: train_pb2.CheckpointVersion.V1 or V2 enum indicating
       whether to load checkpoints in V1 style or V2 style. In this binary
       we only support V2 style (object-based) checkpoints.
+    run_model_on_dummy_input: Whether to run the model on a dummy input in order
+      to ensure that all model variables have been built successfully before
+      loading the fine_tune_checkpoint.
     input_dataset: The tf.data Dataset the model is being trained on. Needed
       to get the shapes for the dummy loss computation.
     unpad_groundtruth_tensors: A parameter passed to unstack_batch.
@@ -349,34 +388,8 @@ def load_fine_tune_checkpoint(
   if checkpoint_version == train_pb2.CheckpointVersion.V1:
     raise ValueError('Checkpoint version should be V2')
 
-  features, labels = iter(input_dataset).next()
-
-  @tf.function
-  def _dummy_computation_fn(features, labels):
-    model._is_training = False  # pylint: disable=protected-access
-    tf.keras.backend.set_learning_phase(False)
-
-    labels = model_lib.unstack_batch(
-        labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
-
-    return _compute_losses_and_predictions_dicts(
-        model,
-        features,
-        labels)
-
-  strategy = tf.compat.v2.distribute.get_strategy()
-  if hasattr(tf.distribute.Strategy, 'run'):
-    strategy.run(
-        _dummy_computation_fn, args=(
-            features,
-            labels,
-        ))
-  else:
-    strategy.experimental_run_v2(
-        _dummy_computation_fn, args=(
-            features,
-            labels,
-        ))
+  if run_model_on_dummy_input:
+    _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors)
 
   restore_from_objects_dict = model.restore_from_objects(
       fine_tune_checkpoint_type=checkpoint_type)
@@ -516,13 +529,6 @@ def train_loop(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
-    # We run the detection_model on dummy inputs in order to ensure that the
-    # model and all its variables have been properly constructed. Specifically,
-    # this is currently necessary prior to (potentially) creating shadow copies
-    # of the model variables for the EMA optimizer.
-    dummy_image, dummy_shapes = detection_model.preprocess(
-        tf.zeros([1, 512, 512, 3], dtype=tf.float32))
-    dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
 
     def train_dataset_fn(input_context):
       """Callable to create train input."""
@@ -545,7 +551,14 @@ def train_loop(
         aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
     optimizer, (learning_rate,) = optimizer_builder.build(
         train_config.optimizer, global_step=global_step)
 
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
     if train_config.optimizer.use_moving_average:
+      _ensure_model_is_built(detection_model, train_input,
+                             unpad_groundtruth_tensors)
       optimizer.shadow_copy(detection_model)
 
     if callable(learning_rate):
@@ -577,12 +590,11 @@ def train_loop(
           lambda: global_step % num_steps_per_iteration == 0):
         # Load a fine-tuning checkpoint.
         if train_config.fine_tune_checkpoint:
-          load_fine_tune_checkpoint(detection_model,
-                                    train_config.fine_tune_checkpoint,
-                                    fine_tune_checkpoint_type,
-                                    fine_tune_checkpoint_version,
-                                    train_input,
-                                    unpad_groundtruth_tensors)
+          load_fine_tune_checkpoint(
+              detection_model, train_config.fine_tune_checkpoint,
+              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
+              train_config.run_fine_tune_checkpoint_dummy_computation,
+              train_input, unpad_groundtruth_tensors)
 
         ckpt = tf.compat.v2.train.Checkpoint(
             step=global_step, model=detection_model, optimizer=optimizer)
@@ -1080,13 +1092,6 @@ def eval_continuously(
   with strategy.scope():
     detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)
-    # We run the detection_model on dummy inputs in order to ensure that the
-    # model and all its variables have been properly constructed. Specifically,
-    # this is currently necessary prior to (potentially) creating shadow copies
-    # of the model variables for the EMA optimizer.
-    # dummy_image, dummy_shapes = detection_model.preprocess(
-    #     tf.zeros([1, 512, 512, 3], dtype=tf.float32))
-    # dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
 
   eval_input = strategy.experimental_distribute_dataset(
       inputs.eval_input(
@@ -1106,7 +1111,14 @@ def eval_continuously(
     ckpt = tf.compat.v2.train.Checkpoint(
         step=global_step, model=detection_model, optimizer=optimizer)
 
+    # We run the detection_model on dummy inputs in order to ensure that the
+    # model and all its variables have been properly constructed. Specifically,
+    # this is currently necessary prior to (potentially) creating shadow copies
+    # of the model variables for the EMA optimizer.
     if eval_config.use_moving_averages:
+      unpad_groundtruth_tensors = (eval_config.batch_size == 1 and not use_tpu)
+      _ensure_model_is_built(detection_model, eval_input,
+                             unpad_groundtruth_tensors)
       optimizer.shadow_copy(detection_model)
 
     ckpt.restore(latest_checkpoint).expect_partial()
...
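Note: the following is an illustrative sketch (not part of this commit) of how the updated load_fine_tune_checkpoint signature might be called from user code. The wrapper name restore_fine_tune_checkpoint is hypothetical, and detection_model, train_input, and ckpt_path are assumed to be built elsewhere with the usual Object Detection API utilities; the point is that the dummy forward pass is now opt-in via run_model_on_dummy_input.

# Illustrative sketch only; assumes the TF Object Detection API is installed
# and that `detection_model`, `train_input`, and `ckpt_path` exist already.
from object_detection import model_lib_v2
from object_detection.protos import train_pb2


def restore_fine_tune_checkpoint(detection_model, train_input, ckpt_path,
                                 needs_dummy_pass=True):
  """Hypothetical helper: restore a fine-tune checkpoint before training."""
  model_lib_v2.load_fine_tune_checkpoint(
      detection_model, ckpt_path,
      checkpoint_type='detection',
      checkpoint_version=train_pb2.CheckpointVersion.V2,
      # New argument from this commit: models whose variables are created at
      # construction time (e.g. CenterNet) can pass False to skip the dummy
      # loss computation.
      run_model_on_dummy_input=needs_dummy_pass,
      input_dataset=train_input,
      unpad_groundtruth_tensors=True)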
@@ -14,7 +14,7 @@ enum CheckpointVersion {
 // Message for configuring DetectionModel training jobs (train.py).
-// Next id: 30
+// Next id: 31
 message TrainConfig {
   // Effective batch size to use for training.
   // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
@@ -66,6 +66,18 @@ message TrainConfig {
   // will raise an error. Instead, set fine_tune_checkpoint_type: 'full'.
   optional bool load_all_detection_checkpoint_vars = 19 [default = false];
 
+  // Whether to run dummy computation when loading a `fine_tune_checkpoint`.
+  // This option is true by default since it is often necessary to run the model
+  // on a dummy input before loading a `fine_tune_checkpoint`, in order to
+  // ensure that all the model variables have already been built successfully.
+  // Some meta architectures, like CenterNet, do not require dummy computation
+  // to successfully load all checkpoint variables, and in these cases this
+  // flag may be set to false to reduce startup time and memory consumption.
+  // Note, this flag only affects dummy computation when loading a
+  // `fine_tune_checkpoint`, e.g. it does not affect the dummy computation that
+  // is run when creating shadow copies of model variables when using EMA.
+  optional bool run_fine_tune_checkpoint_dummy_computation = 30 [default=true];
+
   // Number of steps to train the DetectionModel for. If 0, will train the model
   // indefinitely.
   optional uint32 num_steps = 9 [default=0];
...
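Usage note (a sketch, not part of this commit): the new field can be turned off for architectures that build all of their variables at construction time. The snippet below sets it programmatically on a TrainConfig message; in a pipeline.config text proto the equivalent is a run_fine_tune_checkpoint_dummy_computation: false line inside the train_config block. The checkpoint path is a placeholder.

# Sketch: disabling the new flag on a TrainConfig proto. Assumes the compiled
# object_detection protos are importable.
from object_detection.protos import train_pb2

train_config = train_pb2.TrainConfig()
train_config.fine_tune_checkpoint = 'PATH_TO_BE_CONFIGURED/ckpt-0'
train_config.fine_tune_checkpoint_type = 'detection'
# CenterNet-style models build their variables when constructed, so the dummy
# forward pass before restoring the fine-tune checkpoint can be skipped.
train_config.run_fine_tune_checkpoint_dummy_computation = False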