Commit e9f85e83 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Updating model to infer when labels are padded, rather than relying solely on...

Updating model to infer when labels are padded, rather than relying solely on the mode. This is necessary for evaluating on train data.

PiperOrigin-RevId: 190960173
parent 312b09e2
...@@ -234,7 +234,7 @@ def create_train_input_fn(train_config, train_input_config, ...@@ -234,7 +234,7 @@ def create_train_input_fn(train_config, train_input_config,
features[fields.InputDataFields.true_image_shape] is a [batch_size, 3] features[fields.InputDataFields.true_image_shape] is a [batch_size, 3]
int32 tensor representing the true image shapes, as preprocessed int32 tensor representing the true image shapes, as preprocessed
images could be padded. images could be padded.
features[fields.InputDataFields.image] (optional) is a features[fields.InputDataFields.original_image] (optional) is a
[batch_size, H, W, C] float32 tensor with original images. [batch_size, H, W, C] float32 tensor with original images.
labels: Dictionary of groundtruth tensors. labels: Dictionary of groundtruth tensors.
labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size] labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size]
......
...@@ -225,7 +225,14 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False): ...@@ -225,7 +225,14 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
labels, labels,
unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors) unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
elif mode == tf.estimator.ModeKeys.EVAL: elif mode == tf.estimator.ModeKeys.EVAL:
labels = unstack_batch(labels, unpad_groundtruth_tensors=False) # For evaling on train data, it is necessary to check whether groundtruth
# must be unpadded.
boxes_shape = (
labels[fields.InputDataFields.groundtruth_boxes].get_shape()
.as_list())
unpad_groundtruth_tensors = True if boxes_shape[1] is not None else False
labels = unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes] gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment