Commit 13c9de39 authored by Yubin Ruan's avatar Yubin Ruan
Browse files

fix code style problem and add option "last_layers_contain_logits_only"

parent 17ba1ca4
...@@ -13,18 +13,15 @@ ...@@ -13,18 +13,15 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import glob
import math import math
import os import os
import random import random
import string import string
import sys import sys
from PIL import Image
import build_data import build_data
import tensorflow as tf import tensorflow as tf
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
flags = tf.app.flags
tf.app.flags.DEFINE_string( tf.app.flags.DEFINE_string(
'train_image_folder', 'train_image_folder',
...@@ -52,18 +49,18 @@ tf.app.flags.DEFINE_string( ...@@ -52,18 +49,18 @@ tf.app.flags.DEFINE_string(
_NUM_SHARDS = 4 _NUM_SHARDS = 4
def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir): def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
""" Convert the ADE20k dataset into into tfrecord format (SSTable). """ Converts the ADE20k dataset into into tfrecord format (SSTable).
Args: Args:
dataset_split: dataset split (e.g., train, val) dataset_split: Dataset split (e.g., train, val).
dataset_dir: dir in which the dataset is located dataset_dir: Dir in which the dataset is located.
dataset_label_dir: dir in which the annotations are located dataset_label_dir: Dir in which the annotations are located.
Raises: Raises:
RuntimeError: If loaded image and label have different shape. RuntimeError: If loaded image and label have different shape.
""" """
img_names = glob.glob(os.path.join(dataset_dir, '*.jpg')) img_names = tf.gfile.Glob(os.path.join(dataset_dir, '*.jpg'))
random.shuffle(img_names) random.shuffle(img_names)
seg_names = [] seg_names = []
for f in img_names: for f in img_names:
...@@ -74,7 +71,7 @@ def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir): ...@@ -74,7 +71,7 @@ def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
seg_names.append(seg) seg_names.append(seg)
num_images = len(img_names) num_images = len(img_names)
num_per_shard = int(math.ceil(num_images) / float(_NUM_SHARDS)) num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
image_reader = build_data.ImageReader('jpeg', channels=3) image_reader = build_data.ImageReader('jpeg', channels=3)
label_reader = build_data.ImageReader('png', channels=1) label_reader = build_data.ImageReader('png', channels=1)
......
...@@ -50,7 +50,6 @@ The Example proto contains the following fields: ...@@ -50,7 +50,6 @@ The Example proto contains the following fields:
image/segmentation/class/encoded: encoded semantic segmentation content. image/segmentation/class/encoded: encoded semantic segmentation content.
image/segmentation/class/format: semantic segmentation file format. image/segmentation/class/format: semantic segmentation file format.
""" """
import glob
import math import math
import os.path import os.path
import sys import sys
...@@ -133,7 +132,7 @@ def _convert_dataset(dataset_split): ...@@ -133,7 +132,7 @@ def _convert_dataset(dataset_split):
def main(unused_argv): def main(unused_argv):
dataset_splits = glob.glob(os.path.join(FLAGS.list_folder, '*.txt')) dataset_splits = tf.gfile.Glob(os.path.join(FLAGS.list_folder, '*.txt'))
for dataset_split in dataset_splits: for dataset_split in dataset_splits:
_convert_dataset(dataset_split) _convert_dataset(dataset_split)
......
...@@ -39,27 +39,27 @@ set -e ...@@ -39,27 +39,27 @@ set -e
CURRENT_DIR=$(pwd) CURRENT_DIR=$(pwd)
WORK_DIR="./ADE20K" WORK_DIR="./ADE20K"
mkdir -p ${WORK_DIR} mkdir -p "${WORK_DIR}"
cd ${WORK_DIR} cd "${WORK_DIR}"
# Helper function to download and unpack ADE20K dataset. # Helper function to download and unpack ADE20K dataset.
download_and_uncompress() { download_and_uncompress() {
local BASE_URL=${1} local BASE_URL=${1}
local FILENAME=${2} local FILENAME=${2}
if [ ! -f ${FILENAME} ]; then if [ ! -f "${FILENAME}" ]; then
echo "Downloading ${FILENAME} to ${WORK_DIR}" echo "Downloading ${FILENAME} to ${WORK_DIR}"
wget -nd -c "${BASE_URL}/${FILENAME}" wget -nd -c "${BASE_URL}/${FILENAME}"
fi fi
echo "Uncompressing ${FILENAME}" echo "Uncompressing ${FILENAME}"
unzip ${FILENAME} unzip "${FILENAME}"
} }
# Download the images. # Download the images.
BASE_URL="http://data.csail.mit.edu/places/ADEchallenge" BASE_URL="http://data.csail.mit.edu/places/ADEchallenge"
FILENAME="ADEChallengeData2016.zip" FILENAME="ADEChallengeData2016.zip"
download_and_uncompress ${BASE_URL} ${FILENAME} download_and_uncompress "${BASE_URL}" "${FILENAME}"
cd "${CURRENT_DIR}" cd "${CURRENT_DIR}"
......
...@@ -37,27 +37,27 @@ set -e ...@@ -37,27 +37,27 @@ set -e
CURRENT_DIR=$(pwd) CURRENT_DIR=$(pwd)
WORK_DIR="./pascal_voc_seg" WORK_DIR="./pascal_voc_seg"
mkdir -p ${WORK_DIR} mkdir -p "${WORK_DIR}"
cd ${WORK_DIR} cd "${WORK_DIR}"
# Helper function to download and unpack VOC 2012 dataset. # Helper function to download and unpack VOC 2012 dataset.
download_and_uncompress() { download_and_uncompress() {
local BASE_URL=${1} local BASE_URL=${1}
local FILENAME=${2} local FILENAME=${2}
if [ ! -f ${FILENAME} ]; then if [ ! -f "${FILENAME}" ]; then
echo "Downloading ${FILENAME} to ${WORK_DIR}" echo "Downloading ${FILENAME} to ${WORK_DIR}"
wget -nd -c "${BASE_URL}/${FILENAME}" wget -nd -c "${BASE_URL}/${FILENAME}"
fi fi
echo "Uncompressing ${FILENAME}" echo "Uncompressing ${FILENAME}"
tar -xf ${FILENAME} tar -xf "${FILENAME}"
} }
# Download the images. # Download the images.
BASE_URL="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/" BASE_URL="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/"
FILENAME="VOCtrainval_11-May-2012.tar" FILENAME="VOCtrainval_11-May-2012.tar"
download_and_uncompress ${BASE_URL} ${FILENAME} download_and_uncompress "${BASE_URL}" "${FILENAME}"
cd "${CURRENT_DIR}" cd "${CURRENT_DIR}"
......
...@@ -31,6 +31,11 @@ images for the training, validation and test respectively. ...@@ -31,6 +31,11 @@ images for the training, validation and test respectively.
The Cityscapes dataset contains 19 semantic labels (such as road, person, car, The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
and so on) for urban street scenes. and so on) for urban street scenes.
3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K)
The ADE20K dataset contains 150 semantic labels for both urban street scenes and
indoor scenes.
References: References:
M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn, M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
and A. Zisserman, The pascal visual object classes challenge a retrospective. and A. Zisserman, The pascal visual object classes challenge a retrospective.
...@@ -39,6 +44,9 @@ References: ...@@ -39,6 +44,9 @@ References:
M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson, M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
scene understanding," In Proc. of CVPR, 2016. scene understanding," In Proc. of CVPR, 2016.
B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
through ADE20K dataset", In Proc. of CVPR, 2017.
""" """
import collections import collections
import os.path import os.path
...@@ -87,12 +95,10 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor( ...@@ -87,12 +95,10 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
# These numbers (i.e., 'train'/'test') seem to have to be hard coded # These numbers (i.e., 'train'/'test') seem to have to be hard coded
# You are required to figure it out for your training/testing example. # You are required to figure it out for your training/testing example.
# Is there a way to automatically figure it out ?
_ADE20K_INFORMATION = DatasetDescriptor( _ADE20K_INFORMATION = DatasetDescriptor(
splits_to_sizes = { splits_to_sizes = {
'train': 20210, # num of samples in images/training 'train': 20210, # num of samples in images/training
'val': 2000, # num of samples in images/validation 'val': 2000, # num of samples in images/validation
'eval': 2,
}, },
num_classes=150, num_classes=150,
ignore_label=255, ignore_label=255,
......
...@@ -67,6 +67,7 @@ python deeplab/train.py \ ...@@ -67,6 +67,7 @@ python deeplab/train.py \
--fine_tune_batch_norm=False \ --fine_tune_batch_norm=False \
--dataset="ade20k" \ --dataset="ade20k" \
--initialize_last_layer=False \ --initialize_last_layer=False \
--last_layers_contain_logits_only=True \
--tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \ --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
--train_logdir=${PATH_TO_TRAIN_DIR}\ --train_logdir=${PATH_TO_TRAIN_DIR}\
--dataset_dir=${PATH_TO_DATASET} --dataset_dir=${PATH_TO_DATASET}
...@@ -90,7 +91,7 @@ which the ADE20K dataset resides (the `tfrecord` above) ...@@ -90,7 +91,7 @@ which the ADE20K dataset resides (the `tfrecord` above)
fine_tune_batch_norm = False. fine_tune_batch_norm = False.
2. User should fine tune the `min_resize_value` and `max_resize_value` to get 2. User should fine tune the `min_resize_value` and `max_resize_value` to get
better result. Note that `resize_factor` has to equals to `output_stride`. better result. Note that `resize_factor` has to be equal to `output_stride`.
3. The users should change atrous_rates from [6, 12, 18] to [12, 24, 36] if 3. The users should change atrous_rates from [6, 12, 18] to [12, 24, 36] if
setting output_stride=8. setting output_stride=8.
......
...@@ -64,19 +64,26 @@ _CONCAT_PROJECTION_SCOPE = 'concat_projection' ...@@ -64,19 +64,26 @@ _CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder' _DECODER_SCOPE = 'decoder'
def get_extra_layer_scopes(): def get_extra_layer_scopes(last_layers_contain_logits_only=False):
"""Gets the scopes for extra layers. """Gets the scopes for extra layers.
Args:
last_layers_contain_logits_only: Boolean, True if only consider logits as
the last layer (i.e., exclude ASPP module, decoder module and so on)
Returns: Returns:
A list of scopes for extra layers. A list of scopes for extra layers.
""" """
return [ if last_layers_contain_logits_only:
_LOGITS_SCOPE_NAME, return [_LOGITS_SCOPE_NAME]
_IMAGE_POOLING_SCOPE, else:
_ASPP_SCOPE, return [
_CONCAT_PROJECTION_SCOPE, _LOGITS_SCOPE_NAME,
_DECODER_SCOPE, _IMAGE_POOLING_SCOPE,
] _ASPP_SCOPE,
_CONCAT_PROJECTION_SCOPE,
_DECODER_SCOPE,
]
def predict_labels_multi_scale(images, def predict_labels_multi_scale(images,
......
...@@ -118,6 +118,9 @@ flags.DEFINE_string('tf_initial_checkpoint', None, ...@@ -118,6 +118,9 @@ flags.DEFINE_string('tf_initial_checkpoint', None,
flags.DEFINE_boolean('initialize_last_layer', True, flags.DEFINE_boolean('initialize_last_layer', True,
'Initialize the last layer.') 'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', False,
'Only consider logits as last layers or not.')
flags.DEFINE_integer('slow_start_step', 0, flags.DEFINE_integer('slow_start_step', 0,
'Training model with small learning rate for few steps.') 'Training model with small learning rate for few steps.')
...@@ -292,7 +295,7 @@ def main(unused_argv): ...@@ -292,7 +295,7 @@ def main(unused_argv):
summaries.add(tf.summary.scalar('total_loss', total_loss)) summaries.add(tf.summary.scalar('total_loss', total_loss))
# Modify the gradients for biases and last layer variables. # Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes() last_layers = model.get_extra_layer_scopes(FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers( grad_mult = train_utils.get_model_gradient_multipliers(
last_layers, FLAGS.last_layer_gradient_multiplier) last_layers, FLAGS.last_layer_gradient_multiplier)
if grad_mult: if grad_mult:
......
...@@ -99,7 +99,7 @@ def get_model_init_fn(train_logdir, ...@@ -99,7 +99,7 @@ def get_model_init_fn(train_logdir,
tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint) tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
# Variables that will not be restored. # Variables that will not be restored.
exclude_list = ['global_step', 'logits'] exclude_list = ['global_step']
if not initialize_last_layer: if not initialize_last_layer:
exclude_list.extend(last_layers) exclude_list.extend(last_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment