Commit f0899f18 authored by aquariusjay, committed by Yukun Zhu

Update the codes due to recent change in DeepLab. (#6287)

parent 4b566d4e
@@ -137,3 +137,27 @@ class VideoModelOptions(common.ModelOptions):
     self.classification_loss = FLAGS.classification_loss

     return self
+
+
+def parse_decoder_output_stride():
+  """Parses decoder output stride.
+
+  FEELVOS assumes decoder_output_stride = 4. Thus, this function is created for
+  this particular purpose.
+
+  Returns:
+    An integer specifying the decoder_output_stride.
+
+  Raises:
+    ValueError: If decoder_output_stride is None or contains more than one
+      element.
+  """
+  if FLAGS.decoder_output_stride:
+    decoder_output_stride = [
+        int(x) for x in FLAGS.decoder_output_stride]
+    if len(decoder_output_stride) != 1:
+      raise ValueError('Expect decoder output stride has only one element.')
+    decoder_output_stride = decoder_output_stride[0]
+  else:
+    raise ValueError('Expect flag decoder output stride not to be None.')
+  return decoder_output_stride
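For readers following the change, a minimal sketch of how this helper behaves is shown below; the _Flags stand-in and the flag value ['4'] are illustrative assumptions, not part of the commit:

# Sketch of parse_decoder_output_stride() outside TensorFlow's flags machinery.
# The _Flags class and the value ['4'] are illustrative only.
class _Flags(object):
  decoder_output_stride = ['4']  # list-of-strings flag, matching the list handling above

FLAGS = _Flags()

def parse_decoder_output_stride():
  if FLAGS.decoder_output_stride:
    strides = [int(x) for x in FLAGS.decoder_output_stride]
    if len(strides) != 1:
      raise ValueError('Expect decoder output stride has only one element.')
    return strides[0]
  raise ValueError('Expect flag decoder output stride not to be None.')

print(parse_decoder_output_stride())  # prints 4; ['4', '8'] or None would raise ValueError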
@@ -359,9 +359,10 @@ def multi_scale_logits_with_nearest_neighbor_matching(
       if model_options.crop_size else tf.shape(images)[2])
   # Compute the height, width for the output logits.
-  logits_output_stride = (
-      model_options.decoder_output_stride or model_options.output_stride)
+  if model_options.decoder_output_stride:
+    logits_output_stride = min(model_options.decoder_output_stride)
+  else:
+    logits_output_stride = model_options.output_stride
   logits_height = scale_dimension(
       crop_height,
       max(1.0, max(image_pyramid)) / logits_output_stride)
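Because decoder_output_stride is now a list, the replacement picks the finest stride with min() when a decoder is configured. A small worked example of the resulting logits size follows; the crop size and the scale_dimension rounding rule are assumptions for illustration:

# Assumed stand-in for DeepLab's scale_dimension rounding: int((dim - 1) * scale + 1).
def scale_dimension(dim, scale):
  return int((float(dim) - 1.0) * scale + 1.0)

decoder_output_stride = [4]   # now a list in model_options
output_stride = 16
image_pyramid = [1.0]
crop_height = 465             # illustrative crop size

if decoder_output_stride:
  logits_output_stride = min(decoder_output_stride)
else:
  logits_output_stride = output_stride

logits_height = scale_dimension(
    crop_height, max(1.0, max(image_pyramid)) / logits_output_stride)
print(logits_output_stride, logits_height)  # 4 117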
@@ -266,7 +266,7 @@ def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
             preceding_frame_label[n, tf.newaxis],
             samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                   tf.newaxis],
-            FLAGS.decoder_output_stride,
+            common.parse_decoder_output_stride(),
             reduce_labels=True)
         init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
         init_softmax.append(init_softmax_n)
@@ -610,7 +610,7 @@ def _get_dataset_and_samples(config, train_crop_size, dataset_name,
         is_training=True,
         model_variant=FLAGS.model_variant,
         batch_capacity_factor=FLAGS.batch_capacity_factor,
-        decoder_output_stride=FLAGS.decoder_output_stride,
+        decoder_output_stride=common.parse_decoder_output_stride(),
         first_frame_finetuning=first_frame_finetuning,
         sample_only_first_frame_for_finetuning=
         FLAGS.sample_only_first_frame_for_finetuning,
@@ -482,20 +482,17 @@ def get_embeddings(images, model_options, embedding_dimension):
       is_training=False)
   if model_options.decoder_output_stride is not None:
+    decoder_output_stride = min(model_options.decoder_output_stride)
     if model_options.crop_size is None:
       height = tf.shape(images)[1]
       width = tf.shape(images)[2]
     else:
       height, width = model_options.crop_size
-    decoder_height = model.scale_dimension(
-        height, 1.0 / model_options.decoder_output_stride)
-    decoder_width = model.scale_dimension(
-        width, 1.0 / model_options.decoder_output_stride)
     features = model.refine_by_decoder(
         features,
         end_points,
-        decoder_height=decoder_height,
-        decoder_width=decoder_width,
+        crop_size=[height, width],
+        decoder_output_stride=[decoder_output_stride],
         decoder_use_separable_conv=model_options.decoder_use_separable_conv,
         model_variant=model_options.model_variant,
         is_training=False)
@@ -596,21 +593,20 @@ def get_logits_with_matching(images,
       is_training=is_training,
       fine_tune_batch_norm=fine_tune_batch_norm)
-  if model_options.decoder_output_stride is not None:
+  if model_options.decoder_output_stride:
+    decoder_output_stride = min(model_options.decoder_output_stride)
     if model_options.crop_size is None:
       height = tf.shape(images)[1]
       width = tf.shape(images)[2]
     else:
       height, width = model_options.crop_size
-    decoder_height = model.scale_dimension(
-        height, 1.0 / model_options.decoder_output_stride)
-    decoder_width = model.scale_dimension(
-        width, 1.0 / model_options.decoder_output_stride)
+    decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
+    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
     features = model.refine_by_decoder(
         features,
         end_points,
-        decoder_height=decoder_height,
-        decoder_width=decoder_width,
+        crop_size=[height, width],
+        decoder_output_stride=[decoder_output_stride],
         decoder_use_separable_conv=model_options.decoder_use_separable_conv,
         model_variant=model_options.model_variant,
         weight_decay=weight_decay,
@@ -222,7 +222,7 @@ def create_predictions(samples, reference_labels, first_frame_img,
   init_labels = tf.squeeze(reference_labels, axis=-1)
   init_softmax = embedding_utils.create_initial_softmax_from_labels(
-      reference_labels, reference_labels, FLAGS.decoder_output_stride,
+      reference_labels, reference_labels, common.parse_decoder_output_stride(),
       reduce_labels=False)
   if FLAGS.save_embeddings:
     decoder_height = tf.shape(init_softmax)[1]
@@ -298,7 +298,7 @@ def create_predictions_fast(samples, reference_labels, first_frame_img,
       first_frame_img[tf.newaxis], model_options, FLAGS.embedding_dimension)
   init_labels = tf.squeeze(reference_labels, axis=-1)
   init_softmax = embedding_utils.create_initial_softmax_from_labels(
-      reference_labels, reference_labels, FLAGS.decoder_output_stride,
+      reference_labels, reference_labels, common.parse_decoder_output_stride(),
       reduce_labels=False)
   init = (init_labels, init_softmax, first_frame_embeddings)
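Both call sites above now feed the parsed integer stride into embedding_utils.create_initial_softmax_from_labels. As a rough sketch of the idea, assuming the helper reduces the reference label map to the decoder resolution and one-hot encodes it (the actual tensor implementation in embedding_utils may resize differently):

import numpy as np

def initial_softmax_from_labels(labels, decoder_output_stride, num_classes=2):
  """Rough stand-in: subsample the label map to the decoder grid and one-hot
  encode it as an 'initial softmax'. Illustrative only."""
  down = labels[::decoder_output_stride, ::decoder_output_stride]
  return np.eye(num_classes)[down]

labels = np.zeros((8, 8), dtype=np.int32)
labels[2:6, 2:6] = 1                                   # toy foreground object
softmax = initial_softmax_from_labels(labels, decoder_output_stride=4)
print(softmax.shape)                                   # (2, 2, 2)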