Commit 90d1a0bb authored by Vivek Rathod, committed by TF Object Detection Team

Enable evaluation under distribution strategy.

Run inference under a distribution strategy, gather the outputs locally, and evaluate the results with COCO tools on CPU.

PiperOrigin-RevId: 341162083
parent 970f6567
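
Editor's note, not part of the commit: the pattern this change introduces is to run the per-batch inference step with strategy.run, pull per-replica outputs back to the host with experimental_local_results, concatenate them along the batch dimension, and only then hand them to CPU-side evaluation. A minimal, self-contained sketch of that pattern, using a toy step function instead of the detection model and a single-CPU MirroredStrategy purely for illustration:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['/cpu:0'])
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def inference_step(batch):
  # Stand-in for model.postprocess: any dict of per-example tensors.
  return {'scores': tf.reduce_sum(batch, axis=-1)}

for batch in dist_dataset:
  per_replica = strategy.run(inference_step, args=(batch,))
  # experimental_local_results yields one tensor per replica for each value.
  local = tf.nest.map_structure(
      strategy.experimental_local_results, per_replica)
  gathered = {key: tf.concat(values, axis=0) for key, values in local.items()}
  # `gathered` now holds host-local tensors ready for CPU-side evaluation.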
@@ -891,7 +891,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
 def eval_input(eval_config, eval_input_config, model_config,
-               model=None, params=None):
+               model=None, params=None, input_context=None):
   """Returns `features` and `labels` tensor dictionaries for evaluation.
   Args:
@@ -901,6 +901,9 @@ def eval_input(eval_config, eval_input_config, model_config,
     model: A pre-constructed Detection Model.
       If None, one will be created from the config.
     params: Parameter dictionary passed from the estimator.
+    input_context: optional, A tf.distribute.InputContext object used to
+      shard filenames and compute per-replica batch_size when this function
+      is being called per-replica.
   Returns:
     A tf.data.Dataset that holds (features, labels) tuple.
@@ -1021,6 +1024,7 @@ def eval_input(eval_config, eval_input_config, model_config,
       eval_input_config,
       batch_size=params['batch_size'] if params else eval_config.batch_size,
       transform_input_data_fn=transform_and_pad_input_data_fn,
+      input_context=input_context,
       reduce_to_frame_fn=reduce_to_frame_fn)
   return dataset
...
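
Editor's note, not part of the commit: a tf.distribute.InputContext, like the new input_context argument above, is what tf.distribute passes to a per-replica input function so it can shard data and size batches. The dataset builder that consumes it is not shown in this diff; the sketch below only illustrates the usual pattern, with a hypothetical build_eval_dataset helper:

import tensorflow as tf

def build_eval_dataset(filenames, global_batch_size, input_context=None):
  # `filenames` stands in for the shards listed in the eval input config.
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  batch_size = global_batch_size
  if input_context is not None:
    # Each input pipeline reads a disjoint subset of the files ...
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    # ... and batches with the per-replica, not the global, batch size.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  return dataset.batch(batch_size)

# Example: 2 input pipelines, this one is pipeline 0, 2 replicas in sync.
context = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=2)
files = ['shard-%05d' % i for i in range(8)]
ds = build_eval_dataset(files, global_batch_size=4, input_context=context)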
@@ -104,10 +104,10 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
         numpy array of keypoint visibilities with shape [num_gt_boxes,
         num_keypoints]. Integer is treated as an enum with 0=not labeled,
         1=labeled but not visible and 2=labeled and visible.
-      InputDataFields.groundtruth_labeled_classes (optional): a dictionary of
-        image_id to groundtruth_labeled_class, where groundtruth_labeled_class
-        is a 1-indexed integer numpy array indicating which classes have been
-        annotated over the image.
+      InputDataFields.groundtruth_labeled_classes (optional): a tensor of
+        shape [num_classes + 1] containing the multi-hot tensor indicating the
+        classes that each image is labeled for. Note that the class labels
+        are 1-indexed.
     """
     if image_id in self._image_ids:
       tf.logging.warning('Ignoring ground truth with image id %s since it was '
@@ -150,8 +150,19 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
     self._annotation_id += groundtruth_dict[standard_fields.InputDataFields.
                                             groundtruth_boxes].shape[0]
-    self._groundtruth_labeled_classes[image_id] = groundtruth_dict.get(
-        standard_fields.InputDataFields.groundtruth_labeled_classes)
+    if (standard_fields.InputDataFields.groundtruth_labeled_classes
+       ) in groundtruth_dict:
+      labeled_classes = groundtruth_dict[
+          standard_fields.InputDataFields.groundtruth_labeled_classes]
+      if labeled_classes.shape != (len(self._category_id_set) + 1,):
+        raise ValueError('Invalid shape for groundtruth labeled classes: {}, '
+                         'num_categories_including_background: {}'.format(
+                             labeled_classes,
+                             len(self._category_id_set) + 1))
+      self._groundtruth_labeled_classes[image_id] = np.flatnonzero(
+          groundtruth_dict[standard_fields.InputDataFields
+                           .groundtruth_labeled_classes] == 1).tolist()
     # Boolean to indicate whether a detection has been added for this image.
     self._image_ids[image_id] = False
@@ -373,7 +384,11 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
       # detection_classes. This assumes that all predictions will be kept to
       # compute eval metrics.
       if groundtruth_labeled_classes is None:
-        groundtruth_labeled_classes = detection_classes
+        groundtruth_labeled_classes = tf.reduce_max(
+            tf.one_hot(
+                tf.cast(detection_classes, tf.int32),
+                len(self._category_id_set) + 1),
+            axis=-2)
       if not image_id.shape.as_list():
         # Apply a batch dimension to all tensors.
...
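
Editor's note, not part of the commit: the evaluator now stores labeled classes as a list of 1-indexed category ids recovered from the multi-hot vector via np.flatnonzero. A worked example, assuming three categories plus the background slot at index 0:

import numpy as np

labeled_classes = np.array([0., 1., 0., 1.])  # classes 1 and 3 are labeled
assert labeled_classes.shape == (3 + 1,)      # num_classes + 1 (background)
labeled_ids = np.flatnonzero(labeled_classes == 1).tolist()
print(labeled_ids)  # [1, 3] -- the background slot at index 0 stays unused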
@@ -390,7 +390,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
                 np.array([1]),
             # Only class 1 is exhaustively labeled for image1.
             groundtruth_labeled_classes:
-                np.array([1]),
+                np.array([0., 1., 0., 0.]),
             detection_boxes:
                 np.array([[100., 100., 200., 200.], [100., 100., 200.,
                            200.]]),
@@ -405,7 +405,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
             image_id: 'image2',
             groundtruth_boxes: np.array([[50., 50., 100., 100.]]),
             groundtruth_classes: np.array([3]),
-            groundtruth_labeled_classes: np.array([3]),
+            groundtruth_labeled_classes: np.array([0., 0., 0., 1.]),
             detection_boxes: np.array([[50., 50., 100., 100.]]),
             detection_scores: np.array([.7]),
             detection_classes: np.array([3])
@@ -416,7 +416,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
             image_id: 'image3',
             groundtruth_boxes: np.array([[25., 25., 50., 50.]]),
             groundtruth_classes: np.array([2]),
-            groundtruth_labeled_classes: np.array([2]),
+            groundtruth_labeled_classes: np.array([0., 0., 1., 0.]),
             detection_boxes: np.array([[25., 25., 50., 50.]]),
             detection_scores: np.array([.9]),
             detection_classes: np.array([2])
...
@@ -200,22 +200,11 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
   if detection_model.groundtruth_has_field(
       input_data_fields.groundtruth_labeled_classes):
-    labeled_classes_list = detection_model.groundtruth_lists(
-        input_data_fields.groundtruth_labeled_classes)
-    labeled_classes = [
-        tf.where(x)[:, 0] + label_id_offset for x in labeled_classes_list
-    ]
-    if len(labeled_classes) > 1:
-      num_classes = labeled_classes_list[0].shape[0]
-      padded_labeled_classes = []
-      for x in labeled_classes:
-        padding = num_classes - tf.shape(x)[0]
-        padded_labeled_classes.append(tf.pad(x, [[0, padding]]))
-      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
-          padded_labeled_classes)
-    else:
-      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
-          labeled_classes)
+    groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.pad(
+        tf.stack(
+            detection_model.groundtruth_lists(
+                input_data_fields.groundtruth_labeled_classes)),
+        label_id_offset_paddings)
   groundtruth[input_data_fields.num_groundtruth_boxes] = (
       tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
@@ -832,12 +821,14 @@ def create_estimator_and_inputs(run_config,
       train_config=train_config,
       train_input_config=train_input_config,
       model_config=model_config)
-  eval_input_fns = [
-      create_eval_input_fn(
-          eval_config=eval_config,
-          eval_input_config=eval_input_config,
-          model_config=model_config) for eval_input_config in eval_input_configs
-  ]
+  eval_input_fns = []
+  for eval_input_config in eval_input_configs:
+    eval_input_fns.append(
+        create_eval_input_fn(
+            eval_config=eval_config,
+            eval_input_config=eval_input_config,
+            model_config=model_config))
   eval_input_names = [
       eval_input_config.name for eval_input_config in eval_input_configs
   ]
...
@@ -90,7 +90,7 @@ class ModelLibTest(tf.test.TestCase):
     config_kwarg_overrides = _get_config_kwarg_overrides()
     train_steps = 2
-    strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
+    strategy = tf2.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
     with strategy.scope():
       model_lib_v2.train_loop(
           new_pipeline_config_path,
...
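
Editor's note, not part of the commit: the test now requests two CPU devices, which only works if the single physical CPU has been split into two logical devices before TensorFlow initializes them. The test harness setup is not shown in this diff; it is typically arranged along these lines:

import tensorflow as tf

# Must run before any op touches the CPU device.
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])

strategy = tf.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
print(strategy.num_replicas_in_sync)  # 2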
@@ -36,13 +36,6 @@ from object_detection.utils import label_map_util
 from object_detection.utils import ops
 from object_detection.utils import visualization_utils as vutils
-# pylint: disable=g-import-not-at-top
-try:
-  from tensorflow.contrib import tpu as contrib_tpu
-except ImportError:
-  # TF 2.0 doesn't ship with contrib.
-  pass
-# pylint: enable=g-import-not-at-top
 MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
@@ -664,6 +657,106 @@ def train_loop(
   clean_temporary_directories(strategy, summary_writer_filepath)

+def prepare_eval_dict(detections, groundtruth, features):
+  """Prepares eval dictionary containing detections and groundtruth.
+
+  Takes in `detections` from the model, `groundtruth` and `features` returned
+  from the eval tf.data.dataset and creates a dictionary of tensors suitable
+  for detection eval modules.
+
+  Args:
+    detections: A dictionary of tensors returned by `model.postprocess`.
+    groundtruth: `inputs.eval_input` returns an eval dataset of (features,
+      labels) tuple. `groundtruth` must be set to `labels`.
+      Please note that:
+        * fields.InputDataFields.groundtruth_classes must be 0-indexed and
+          in its 1-hot representation.
+        * fields.InputDataFields.groundtruth_verified_neg_classes must be
+          0-indexed and in its multi-hot representation.
+        * fields.InputDataFields.groundtruth_not_exhaustive_classes must be
+          0-indexed and in its multi-hot representation.
+        * fields.InputDataFields.groundtruth_labeled_classes must be
+          0-indexed and in its multi-hot representation.
+    features: `inputs.eval_input` returns an eval dataset of (features, labels)
+      tuple. This argument must be set to a dictionary containing the following
+      keys and their corresponding values from `features` --
+        * fields.InputDataFields.image
+        * fields.InputDataFields.original_image
+        * fields.InputDataFields.original_image_spatial_shape
+        * fields.InputDataFields.true_image_shape
+        * inputs.HASH_KEY
+
+  Returns:
+    eval_dict: A dictionary of tensors to pass to eval module.
+    class_agnostic: Whether to evaluate detection in class agnostic mode.
+  """
+  groundtruth_boxes = groundtruth[fields.InputDataFields.groundtruth_boxes]
+  groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
+  # For class-agnostic models, groundtruth one-hot encodings collapse to all
+  # ones.
+  class_agnostic = (
+      fields.DetectionResultFields.detection_classes not in detections)
+  if class_agnostic:
+    groundtruth_classes_one_hot = tf.ones(
+        [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
+  else:
+    groundtruth_classes_one_hot = groundtruth[
+        fields.InputDataFields.groundtruth_classes]
+  label_id_offset = 1  # Applying label id offset (b/63711816)
+  groundtruth_classes = (
+      tf.argmax(groundtruth_classes_one_hot, axis=2) + label_id_offset)
+  groundtruth[fields.InputDataFields.groundtruth_classes] = groundtruth_classes
+
+  label_id_offset_paddings = tf.constant([[0, 0], [1, 0]])
+  if fields.InputDataFields.groundtruth_verified_neg_classes in groundtruth:
+    groundtruth[
+        fields.InputDataFields.groundtruth_verified_neg_classes] = tf.pad(
+            groundtruth[
+                fields.InputDataFields.groundtruth_verified_neg_classes],
+            label_id_offset_paddings)
+  if fields.InputDataFields.groundtruth_not_exhaustive_classes in groundtruth:
+    groundtruth[
+        fields.InputDataFields.groundtruth_not_exhaustive_classes] = tf.pad(
+            groundtruth[
+                fields.InputDataFields.groundtruth_not_exhaustive_classes],
+            label_id_offset_paddings)
+  if fields.InputDataFields.groundtruth_labeled_classes in groundtruth:
+    groundtruth[fields.InputDataFields.groundtruth_labeled_classes] = tf.pad(
+        groundtruth[fields.InputDataFields.groundtruth_labeled_classes],
+        label_id_offset_paddings)
+
+  use_original_images = fields.InputDataFields.original_image in features
+  if use_original_images:
+    eval_images = features[fields.InputDataFields.original_image]
+    true_image_shapes = features[fields.InputDataFields.true_image_shape][:, :3]
+    original_image_spatial_shapes = features[
+        fields.InputDataFields.original_image_spatial_shape]
+  else:
+    eval_images = features[fields.InputDataFields.image]
+    true_image_shapes = None
+    original_image_spatial_shapes = None
+
+  eval_dict = eval_util.result_dict_for_batched_example(
+      eval_images,
+      features[inputs.HASH_KEY],
+      detections,
+      groundtruth,
+      class_agnostic=class_agnostic,
+      scale_to_absolute=True,
+      original_image_spatial_shapes=original_image_spatial_shapes,
+      true_image_shapes=true_image_shapes)
+
+  return eval_dict, class_agnostic
+
+
+def concat_replica_results(tensor_dict):
+  new_tensor_dict = {}
+  for key, values in tensor_dict.items():
+    new_tensor_dict[key] = tf.concat(values, axis=0)
+  return new_tensor_dict
+
+
 def eager_eval_loop(
     detection_model,
     configs,
@@ -692,6 +785,7 @@ def eager_eval_loop(
   Returns:
     A dict of evaluation metrics representing the results of this evaluation.
   """
+  del postprocess_on_cpu
   train_config = configs['train_config']
   eval_input_config = configs['eval_input_config']
   eval_config = configs['eval_config']
@@ -735,57 +829,26 @@ def eager_eval_loop(
     unpad_groundtruth_tensors = (boxes_shape[1] is not None
                                  and not use_tpu
                                  and batch_size == 1)
+    groundtruth_dict = labels
     labels = model_lib.unstack_batch(
         labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

     losses_dict, prediction_dict = _compute_losses_and_predictions_dicts(
         detection_model, features, labels, add_regularization_loss)
-
-    def postprocess_wrapper(args):
-      return detection_model.postprocess(args[0], args[1])
-
-    # TODO(kaftan): Depending on how postprocessing will work for TPUS w/
-    ## TPUStrategy, may be good to move wrapping to a utility method
-    if use_tpu and postprocess_on_cpu:
-      detections = contrib_tpu.outside_compilation(
-          postprocess_wrapper,
-          (prediction_dict, features[fields.InputDataFields.true_image_shape]))
-    else:
-      detections = postprocess_wrapper(
-          (prediction_dict, features[fields.InputDataFields.true_image_shape]))
-
-    class_agnostic = (
-        fields.DetectionResultFields.detection_classes not in detections)
-    # TODO(kaftan) (or anyone): move `_prepare_groundtruth_for_eval to eval_util
-    ## and call this from there.
-    groundtruth = model_lib._prepare_groundtruth_for_eval(  # pylint: disable=protected-access
-        detection_model, class_agnostic, eval_input_config.max_number_of_boxes)
-    use_original_images = fields.InputDataFields.original_image in features
-    if use_original_images:
-      eval_images = features[fields.InputDataFields.original_image]
-      true_image_shapes = tf.slice(
-          features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
-      original_image_spatial_shapes = features[
-          fields.InputDataFields.original_image_spatial_shape]
-    else:
-      eval_images = features[fields.InputDataFields.image]
-      true_image_shapes = None
-      original_image_spatial_shapes = None
-
-    keys = features[inputs.HASH_KEY]
-    if eval_input_config.include_source_id:
-      keys = features[fields.InputDataFields.source_id]
-
-    eval_dict = eval_util.result_dict_for_batched_example(
-        eval_images,
-        keys,
-        detections,
-        groundtruth,
-        class_agnostic=class_agnostic,
-        scale_to_absolute=True,
-        original_image_spatial_shapes=original_image_spatial_shapes,
-        true_image_shapes=true_image_shapes)
-
-    return eval_dict, losses_dict, class_agnostic
+    prediction_dict = detection_model.postprocess(
+        prediction_dict, features[fields.InputDataFields.true_image_shape])
+    eval_features = {
+        fields.InputDataFields.image:
+            features[fields.InputDataFields.image],
+        fields.InputDataFields.original_image:
+            features[fields.InputDataFields.original_image],
+        fields.InputDataFields.original_image_spatial_shape:
+            features[fields.InputDataFields.original_image_spatial_shape],
+        fields.InputDataFields.true_image_shape:
+            features[fields.InputDataFields.true_image_shape],
+        inputs.HASH_KEY: features[inputs.HASH_KEY],
+    }
+    return losses_dict, prediction_dict, groundtruth_dict, eval_features

   agnostic_categories = label_map_util.create_class_agnostic_category_index()
   per_class_categories = label_map_util.create_category_index_from_labelmap(
@@ -793,9 +856,31 @@ def eager_eval_loop(
   keypoint_edges = [
       (kp.start, kp.end) for kp in eval_config.keypoint_edge]
-  for i, (features, labels) in enumerate(eval_dataset):
-    eval_dict, losses_dict, class_agnostic = compute_eval_dict(features, labels)
+  strategy = tf.compat.v2.distribute.get_strategy()
+
+  for i, (features, labels) in enumerate(eval_dataset):
+    try:
+      (losses_dict, prediction_dict, groundtruth_dict,
+       eval_features) = strategy.run(
+           compute_eval_dict, args=(features, labels))
+    except:  # pylint:disable=bare-except
+      tf.logging.info('A replica probably exhausted all examples. Skipping '
+                      'pending examples on other replicas.')
+      break
+    (local_prediction_dict, local_groundtruth_dict,
+     local_eval_features) = tf.nest.map_structure(
+         strategy.experimental_local_results,
+         [prediction_dict, groundtruth_dict, eval_features])
+    local_prediction_dict = concat_replica_results(local_prediction_dict)
+    local_groundtruth_dict = concat_replica_results(local_groundtruth_dict)
+    local_eval_features = concat_replica_results(local_eval_features)
+    eval_dict, class_agnostic = prepare_eval_dict(local_prediction_dict,
+                                                  local_groundtruth_dict,
+                                                  local_eval_features)
+    for loss_key, loss_tensor in iter(losses_dict.items()):
+      losses_dict[loss_key] = strategy.reduce(tf.distribute.ReduceOp.MEAN,
+                                              loss_tensor, None)
+
     if class_agnostic:
       category_index = agnostic_categories
     else:
@@ -841,20 +926,15 @@ def eager_eval_loop(
     for loss_key, loss_tensor in iter(losses_dict.items()):
       if loss_key not in loss_metrics:
-        loss_metrics[loss_key] = tf.keras.metrics.Mean()
-      # Skip the loss with value equal or lower than 0.0 when calculating the
-      # average loss since they don't usually reflect the normal loss values
-      # causing spurious average loss value.
-      if loss_tensor <= 0.0:
-        continue
-      loss_metrics[loss_key].update_state(loss_tensor)
+        loss_metrics[loss_key] = []
+      loss_metrics[loss_key].append(loss_tensor)

   eval_metrics = {}
   for evaluator in evaluators:
     eval_metrics.update(evaluator.evaluate())
   for loss_key in loss_metrics:
-    eval_metrics[loss_key] = loss_metrics[loss_key].result()
+    eval_metrics[loss_key] = tf.reduce_mean(loss_metrics[loss_key])
   eval_metrics = {str(k): v for k, v in eval_metrics.items()}
   tf.logging.info('Eval metrics at step %d', global_step)
@@ -878,7 +958,7 @@ def eval_continuously(
     checkpoint_dir=None,
     wait_interval=180,
     timeout=3600,
-    eval_index=None,
+    eval_index=0,
     **kwargs):
   """Run continuous evaluation of a detection model eagerly.
@@ -908,8 +988,8 @@ def eval_continuously(
       new checkpoint.
     timeout: The maximum number of seconds to wait for a checkpoint. Execution
       will terminate if no new checkpoints are found after these many seconds.
-    eval_index: int, optional If give, only evaluate the dataset at the given
-      index.
+    eval_index: int, if given, only evaluate the dataset at the given
+      index. By default, evaluates the dataset at the 0th index.
     **kwargs: Additional keyword arguments for configuration override.
   """
@@ -950,21 +1030,18 @@ def eval_continuously(
   if kwargs['use_bfloat16']:
     tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
-  detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
-      model_config=model_config, is_training=True)
-
-  # Create the inputs.
-  eval_inputs = []
-  for eval_input_config in eval_input_configs:
-    next_eval_input = inputs.eval_input(
-        eval_config=eval_config,
-        eval_input_config=eval_input_config,
-        model_config=model_config,
-        model=detection_model)
-    eval_inputs.append((eval_input_config.name, next_eval_input))
-
-  if eval_index is not None:
-    eval_inputs = [eval_inputs[eval_index]]
+  eval_input_config = eval_input_configs[eval_index]
+  strategy = tf.compat.v2.distribute.get_strategy()
+  with strategy.scope():
+    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
+        model_config=model_config, is_training=True)
+
+  eval_input = strategy.experimental_distribute_dataset(
+      inputs.eval_input(
+          eval_config=eval_config,
+          eval_input_config=eval_input_config,
+          model_config=model_config,
+          model=detection_model))

   global_step = tf.compat.v2.Variable(
       0, trainable=False, dtype=tf.compat.v2.dtypes.int64)
@@ -976,14 +1053,13 @@ def eval_continuously(
     ckpt.restore(latest_checkpoint).expect_partial()
-    for eval_name, eval_input in eval_inputs:
-      summary_writer = tf.compat.v2.summary.create_file_writer(
-          os.path.join(model_dir, 'eval', eval_name))
-      with summary_writer.as_default():
-        eager_eval_loop(
-            detection_model,
-            configs,
-            eval_input,
-            use_tpu=use_tpu,
-            postprocess_on_cpu=postprocess_on_cpu,
-            global_step=global_step)
+    summary_writer = tf.compat.v2.summary.create_file_writer(
+        os.path.join(model_dir, 'eval', eval_input_config.name))
+    with summary_writer.as_default():
+      eager_eval_loop(
+          detection_model,
+          configs,
+          eval_input,
+          use_tpu=use_tpu,
+          postprocess_on_cpu=postprocess_on_cpu,
+          global_step=global_step)
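
Editor's note, not part of the commit: the per-replica losses collected inside eager_eval_loop are averaged on the host with strategy.reduce before being reported as metrics. A small standalone sketch of that reduction, with a dummy loss_step standing in for the model:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['/cpu:0'])

def loss_step(_):
  # Dummy per-replica losses standing in for the model's losses_dict.
  return {'Loss/total_loss': tf.constant(1.5)}

per_replica_losses = strategy.run(loss_step, args=(None,))
reduced = {
    key: strategy.reduce(tf.distribute.ReduceOp.MEAN, value, axis=None)
    for key, value in per_replica_losses.items()
}
print(reduced['Loss/total_loss'].numpy())  # 1.5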