Commit 6a67bfdc authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Internal change.

PiperOrigin-RevId: 364349645
parent 0c48b89f
......@@ -17,6 +17,9 @@
import functools
import sys
from absl import logging
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
......@@ -1064,6 +1067,7 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
if center_net_config.HasField('post_processing'):
non_max_suppression_fn, _ = post_processing_builder.build(
center_net_config.post_processing)
return center_net_meta_arch.CenterNetMetaArch(
is_training=is_training,
add_summaries=add_summaries,
......
......@@ -351,32 +351,32 @@ def load_fine_tune_checkpoint(
features, labels = iter(input_dataset).next()
@tf.function
def _dummy_computation_fn(features, labels):
model._is_training = False # pylint: disable=protected-access
tf.keras.backend.set_learning_phase(False)
labels = model_lib.unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
return _compute_losses_and_predictions_dicts(
model,
features,
labels)
strategy = tf.compat.v2.distribute.get_strategy()
if hasattr(tf.distribute.Strategy, 'run'):
strategy.run(
_dummy_computation_fn, args=(
features,
labels,
))
else:
strategy.experimental_run_v2(
_dummy_computation_fn, args=(
features,
labels,
))
# @tf.function
# def _dummy_computation_fn(features, labels):
# model._is_training = False # pylint: disable=protected-access
# tf.keras.backend.set_learning_phase(False)
# labels = model_lib.unstack_batch(
# labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
# return _compute_losses_and_predictions_dicts(
# model,
# features,
# labels)
# strategy = tf.compat.v2.distribute.get_strategy()
# if hasattr(tf.distribute.Strategy, 'run'):
# strategy.run(
# _dummy_computation_fn, args=(
# features,
# labels,
# ))
# else:
# strategy.experimental_run_v2(
# _dummy_computation_fn, args=(
# features,
# labels,
# ))
restore_from_objects_dict = model.restore_from_objects(
fine_tune_checkpoint_type=checkpoint_type)
......@@ -1084,9 +1084,9 @@ def eval_continuously(
# model and all its variables have been properly constructed. Specifically,
# this is currently necessary prior to (potentially) creating shadow copies
# of the model variables for the EMA optimizer.
dummy_image, dummy_shapes = detection_model.preprocess(
tf.zeros([1, 512, 512, 3], dtype=tf.float32))
dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
# dummy_image, dummy_shapes = detection_model.preprocess(
# tf.zeros([1, 512, 512, 3], dtype=tf.float32))
# dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
eval_input = strategy.experimental_distribute_dataset(
inputs.eval_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment