Commit 6c9d2eba authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #751 from stef716/resnet_training

Align model slim/resnet to slim/inception
parents f80d631b cb1e6111
......@@ -270,7 +270,7 @@ def inception_v1(inputs,
is_training: whether is training or not.
dropout_keep_prob: the percentage of activation values that are retained.
prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
......
......@@ -443,7 +443,7 @@ def inception_v2(inputs,
usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model.
prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
......
......@@ -453,7 +453,7 @@ def inception_v3(inputs,
usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model.
prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
......
......@@ -119,6 +119,7 @@ def resnet_v1(inputs,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
reuse=None,
scope=None):
"""Generator for v1 ResNet models.
......@@ -158,6 +159,8 @@ def resnet_v1(inputs,
ratio of input to output spatial resolution.
include_root_block: If True, include the initial convolution followed by
max-pooling, if False excludes it.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
......@@ -197,11 +200,13 @@ def resnet_v1(inputs,
if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits')
if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None:
end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points
end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points
resnet_v1.default_image_size = 224
......
......@@ -117,6 +117,7 @@ def resnet_v2(inputs,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
reuse=None,
scope=None):
"""Generator for v2 (preactivation) ResNet models.
......@@ -157,6 +158,8 @@ def resnet_v2(inputs,
include_root_block: If True, include the initial convolution followed by
max-pooling, if False excludes it. If excluded, `inputs` should be the
results of an activation-less convolution.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
......@@ -206,11 +209,12 @@ def resnet_v2(inputs,
if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits')
if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None:
end_points['predictions'] = slim.softmax(net, scope='predictions')
end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points
resnet_v2.default_image_size = 224
......
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50
# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers
# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
tar -xvf resnet_v1_50_2016_08_28.tar.gz
mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
rm resnet_v1_50_2016_08_28.tar.gz
fi
# Download the dataset
python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir=${DATASET_DIR}
# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50 \
--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
--checkpoint_exclude_scopes=resnet_v1_50/logits \
--trainable_scopes=resnet_v1_50/logits \
--max_number_of_steps=3000 \
--batch_size=32 \
--learning_rate=0.01 \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=100 \
--optimizer=rmsprop \
--weight_decay=0.00004
# Run evaluation.
python eval_image_classifier.py \
--checkpoint_path=${TRAIN_DIR} \
--eval_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50
# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
--train_dir=${TRAIN_DIR}/all \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--checkpoint_path=${TRAIN_DIR} \
--model_name=resnet_v1_50 \
--max_number_of_steps=1000 \
--batch_size=32 \
--learning_rate=0.001 \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=100 \
--optimizer=rmsprop \
--weight_decay=0.00004
# Run evaluation.
python eval_image_classifier.py \
--checkpoint_path=${TRAIN_DIR}/all \
--eval_dir=${TRAIN_DIR}/all \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment