Commit 6c9d2eba authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #751 from stef716/resnet_training

Align model slim/resnet to slim/inception
parents f80d631b cb1e6111
...@@ -270,7 +270,7 @@ def inception_v1(inputs, ...@@ -270,7 +270,7 @@ def inception_v1(inputs,
is_training: whether is training or not. is_training: whether is training or not.
dropout_keep_prob: the percentage of activation values that are retained. dropout_keep_prob: the percentage of activation values that are retained.
prediction_fn: a function to get predictions out of logits. prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
......
...@@ -443,7 +443,7 @@ def inception_v2(inputs, ...@@ -443,7 +443,7 @@ def inception_v2(inputs,
usage will be to set this value in (0, 1) to reduce the number of usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model. parameters or computation cost of the model.
prediction_fn: a function to get predictions out of logits. prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
......
...@@ -453,7 +453,7 @@ def inception_v3(inputs, ...@@ -453,7 +453,7 @@ def inception_v3(inputs,
usage will be to set this value in (0, 1) to reduce the number of usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model. parameters or computation cost of the model.
prediction_fn: a function to get predictions out of logits. prediction_fn: a function to get predictions out of logits.
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
......
...@@ -119,6 +119,7 @@ def resnet_v1(inputs, ...@@ -119,6 +119,7 @@ def resnet_v1(inputs,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
include_root_block=True, include_root_block=True,
spatial_squeeze=True,
reuse=None, reuse=None,
scope=None): scope=None):
"""Generator for v1 ResNet models. """Generator for v1 ResNet models.
...@@ -158,6 +159,8 @@ def resnet_v1(inputs, ...@@ -158,6 +159,8 @@ def resnet_v1(inputs,
ratio of input to output spatial resolution. ratio of input to output spatial resolution.
include_root_block: If True, include the initial convolution followed by include_root_block: If True, include the initial convolution followed by
max-pooling, if False excludes it. max-pooling, if False excludes it.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
scope: Optional variable_scope. scope: Optional variable_scope.
...@@ -197,11 +200,13 @@ def resnet_v1(inputs, ...@@ -197,11 +200,13 @@ def resnet_v1(inputs,
if num_classes is not None: if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(net, scope='predictions') end_points['predictions'] = slim.softmax(logits, scope='predictions')
return net, end_points return logits, end_points
resnet_v1.default_image_size = 224 resnet_v1.default_image_size = 224
......
...@@ -117,6 +117,7 @@ def resnet_v2(inputs, ...@@ -117,6 +117,7 @@ def resnet_v2(inputs,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
include_root_block=True, include_root_block=True,
spatial_squeeze=True,
reuse=None, reuse=None,
scope=None): scope=None):
"""Generator for v2 (preactivation) ResNet models. """Generator for v2 (preactivation) ResNet models.
...@@ -157,6 +158,8 @@ def resnet_v2(inputs, ...@@ -157,6 +158,8 @@ def resnet_v2(inputs,
include_root_block: If True, include the initial convolution followed by include_root_block: If True, include the initial convolution followed by
max-pooling, if False excludes it. If excluded, `inputs` should be the max-pooling, if False excludes it. If excluded, `inputs` should be the
results of an activation-less convolution. results of an activation-less convolution.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
scope: Optional variable_scope. scope: Optional variable_scope.
...@@ -206,11 +209,12 @@ def resnet_v2(inputs, ...@@ -206,11 +209,12 @@ def resnet_v2(inputs,
if num_classes is not None: if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(net, scope='predictions') end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points return logits, end_points
resnet_v2.default_image_size = 224 resnet_v2.default_image_size = 224
......
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50
# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers
# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
tar -xvf resnet_v1_50_2016_08_28.tar.gz
mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
rm resnet_v1_50_2016_08_28.tar.gz
fi
# Download the dataset
python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir=${DATASET_DIR}
# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50 \
--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
--checkpoint_exclude_scopes=resnet_v1_50/logits \
--trainable_scopes=resnet_v1_50/logits \
--max_number_of_steps=3000 \
--batch_size=32 \
--learning_rate=0.01 \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=100 \
--optimizer=rmsprop \
--weight_decay=0.00004
# Run evaluation.
python eval_image_classifier.py \
--checkpoint_path=${TRAIN_DIR} \
--eval_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50
# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
--train_dir=${TRAIN_DIR}/all \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--checkpoint_path=${TRAIN_DIR} \
--model_name=resnet_v1_50 \
--max_number_of_steps=1000 \
--batch_size=32 \
--learning_rate=0.001 \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=100 \
--optimizer=rmsprop \
--weight_decay=0.00004
# Run evaluation.
python eval_image_classifier.py \
--checkpoint_path=${TRAIN_DIR}/all \
--eval_dir=${TRAIN_DIR}/all \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=${DATASET_DIR} \
--model_name=resnet_v1_50
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment