Commit 7a02b5ce authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Restore dummy computation function.

PiperOrigin-RevId: 364388006
parent 2da86542
......@@ -351,32 +351,32 @@ def load_fine_tune_checkpoint(
features, labels = iter(input_dataset).next()
# @tf.function
# def _dummy_computation_fn(features, labels):
# model._is_training = False # pylint: disable=protected-access
# tf.keras.backend.set_learning_phase(False)
# labels = model_lib.unstack_batch(
# labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
# return _compute_losses_and_predictions_dicts(
# model,
# features,
# labels)
# strategy = tf.compat.v2.distribute.get_strategy()
# if hasattr(tf.distribute.Strategy, 'run'):
# strategy.run(
# _dummy_computation_fn, args=(
# features,
# labels,
# ))
# else:
# strategy.experimental_run_v2(
# _dummy_computation_fn, args=(
# features,
# labels,
# ))
@tf.function
def _dummy_computation_fn(features, labels):
model._is_training = False # pylint: disable=protected-access
tf.keras.backend.set_learning_phase(False)
labels = model_lib.unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
return _compute_losses_and_predictions_dicts(
model,
features,
labels)
strategy = tf.compat.v2.distribute.get_strategy()
if hasattr(tf.distribute.Strategy, 'run'):
strategy.run(
_dummy_computation_fn, args=(
features,
labels,
))
else:
strategy.experimental_run_v2(
_dummy_computation_fn, args=(
features,
labels,
))
restore_from_objects_dict = model.restore_from_objects(
fine_tune_checkpoint_type=checkpoint_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment