ModelZoo / ResNet50_tensorflow

Commit 4f9d1024, authored Sep 08, 2016 by Chris Shallue
Parent: 54886315

Open source the image-to-text model based on the "Show and Tell" paper.
Changes: 27 · Showing 7 changed files with 1232 additions and 0 deletions (+1232, -0):
im2txt/im2txt/ops/image_embedding_test.py    +136  -0
im2txt/im2txt/ops/image_processing.py        +134  -0
im2txt/im2txt/ops/inputs.py                  +204  -0
im2txt/im2txt/run_inference.py               +83   -0
im2txt/im2txt/show_and_tell_model.py         +364  -0
im2txt/im2txt/show_and_tell_model_test.py    +200  -0
im2txt/im2txt/train.py                       +111  -0
im2txt/im2txt/ops/image_embedding_test.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.im2txt.ops.image_embedding."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from im2txt.ops import image_embedding


class InceptionV3Test(tf.test.TestCase):

  def setUp(self):
    super(InceptionV3Test, self).setUp()
    batch_size = 4
    height = 299
    width = 299
    num_channels = 3
    self._images = tf.placeholder(tf.float32,
                                  [batch_size, height, width, num_channels])
    self._batch_size = batch_size

  def _countInceptionParameters(self):
    """Counts the number of parameters in the inception model at top scope."""
    counter = {}
    for v in tf.all_variables():
      name_tokens = v.op.name.split("/")
      if name_tokens[0] == "InceptionV3":
        name = "InceptionV3/" + name_tokens[1]
        num_params = v.get_shape().num_elements()
        assert num_params
        counter[name] = counter.get(name, 0) + num_params
    return counter

  def _verifyParameterCounts(self):
    """Verifies the number of parameters in the inception model."""
    param_counts = self._countInceptionParameters()
    expected_param_counts = {
        "InceptionV3/Conv2d_1a_3x3": 960,
        "InceptionV3/Conv2d_2a_3x3": 9312,
        "InceptionV3/Conv2d_2b_3x3": 18624,
        "InceptionV3/Conv2d_3b_1x1": 5360,
        "InceptionV3/Conv2d_4a_3x3": 138816,
        "InceptionV3/Mixed_5b": 256368,
        "InceptionV3/Mixed_5c": 277968,
        "InceptionV3/Mixed_5d": 285648,
        "InceptionV3/Mixed_6a": 1153920,
        "InceptionV3/Mixed_6b": 1298944,
        "InceptionV3/Mixed_6c": 1692736,
        "InceptionV3/Mixed_6d": 1692736,
        "InceptionV3/Mixed_6e": 2143872,
        "InceptionV3/Mixed_7a": 1699584,
        "InceptionV3/Mixed_7b": 5047872,
        "InceptionV3/Mixed_7c": 6080064,
    }
    self.assertDictEqual(expected_param_counts, param_counts)

  def _assertCollectionSize(self, expected_size, collection):
    actual_size = len(tf.get_collection(collection))
    if expected_size != actual_size:
      self.fail("Found %d items in collection %s (expected %d)." %
                (actual_size, collection, expected_size))

  def testTrainableTrueIsTrainingTrue(self):
    embeddings = image_embedding.inception_v3(
        self._images, trainable=True, is_training=True)
    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())

    self._verifyParameterCounts()
    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
    self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
    self._assertCollectionSize(188, tf.GraphKeys.UPDATE_OPS)
    self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)

  def testTrainableTrueIsTrainingFalse(self):
    embeddings = image_embedding.inception_v3(
        self._images, trainable=True, is_training=False)
    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())

    self._verifyParameterCounts()
    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
    self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
    self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)

  def testTrainableFalseIsTrainingTrue(self):
    embeddings = image_embedding.inception_v3(
        self._images, trainable=False, is_training=True)
    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())

    self._verifyParameterCounts()
    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
    self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
    self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)

  def testTrainableFalseIsTrainingFalse(self):
    embeddings = image_embedding.inception_v3(
        self._images, trainable=False, is_training=False)
    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())

    self._verifyParameterCounts()
    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
    self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
    self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)


if __name__ == "__main__":
  tf.test.main()
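
The expected_param_counts table is easiest to sanity-check on the first layer. A back-of-the-envelope sketch, assuming the standard slim InceptionV3 layout (3x3 convolutions without biases, and batch normalization contributing beta, moving_mean and moving_variance per output channel):

# Conv2d_1a_3x3: 3x3 kernel, 3 input channels, 32 output channels.
in_channels, out_channels, kernel = 3, 32, 3
conv_weights = kernel * kernel * in_channels * out_channels  # 864
batch_norm_params = 3 * out_channels                         # 96
print(conv_weights + batch_norm_params)  # 960, matching the table above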

im2txt/im2txt/ops/image_processing.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for image preprocessing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def distort_image(image, thread_id):
  """Perform random distortions on an image.

  Args:
    image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
    thread_id: Preprocessing thread id used to select the ordering of color
      distortions. There should be a multiple of 2 preprocessing threads.

  Returns:
    distorted_image: A float32 Tensor of shape [height, width, 3] with values in
      [0, 1].
  """
  # Randomly flip horizontally.
  with tf.name_scope("flip_horizontal", values=[image]):
    image = tf.image.random_flip_left_right(image)

  # Randomly distort the colors based on thread id.
  color_ordering = thread_id % 2
  with tf.name_scope("distort_color", values=[image]):
    if color_ordering == 0:
      image = tf.image.random_brightness(image, max_delta=32. / 255.)
      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      image = tf.image.random_hue(image, max_delta=0.032)
      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    elif color_ordering == 1:
      image = tf.image.random_brightness(image, max_delta=32. / 255.)
      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      image = tf.image.random_hue(image, max_delta=0.032)

    # The random_* ops do not necessarily clamp.
    image = tf.clip_by_value(image, 0.0, 1.0)

  return image


def process_image(encoded_image,
                  is_training,
                  height,
                  width,
                  resize_height=346,
                  resize_width=346,
                  thread_id=0,
                  image_format="jpeg"):
  """Decode an image, resize and apply random distortions.

  In training, images are distorted slightly differently depending on thread_id.

  Args:
    encoded_image: String Tensor containing the image.
    is_training: Boolean; whether preprocessing for training or eval.
    height: Height of the output image.
    width: Width of the output image.
    resize_height: If > 0, resize height before crop to final dimensions.
    resize_width: If > 0, resize width before crop to final dimensions.
    thread_id: Preprocessing thread id used to select the ordering of color
      distortions. There should be a multiple of 2 preprocessing threads.
    image_format: "jpeg" or "png".

  Returns:
    A float32 Tensor of shape [height, width, 3] with values in [-1, 1].

  Raises:
    ValueError: If image_format is invalid.
  """
  # Helper function to log an image summary to the visualizer. Summaries are
  # only logged in thread 0.
  def image_summary(name, image):
    if not thread_id:
      tf.image_summary(name, tf.expand_dims(image, 0))

  # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
  with tf.name_scope("decode", values=[encoded_image]):
    if image_format == "jpeg":
      image = tf.image.decode_jpeg(encoded_image, channels=3)
    elif image_format == "png":
      image = tf.image.decode_png(encoded_image, channels=3)
    else:
      raise ValueError("Invalid image format: %s" % image_format)
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  image_summary("original_image", image)

  # Resize image.
  assert (resize_height > 0) == (resize_width > 0)
  if resize_height:
    image = tf.image.resize_images(image,
                                   new_height=resize_height,
                                   new_width=resize_width,
                                   method=tf.image.ResizeMethod.BILINEAR)

  # Crop to final dimensions.
  if is_training:
    image = tf.random_crop(image, [height, width, 3])
  else:
    # Central crop, assuming resize_height > height, resize_width > width.
    image = tf.image.resize_image_with_crop_or_pad(image, height, width)

  image_summary("resized_image", image)

  # Randomly distort the image.
  if is_training:
    image = distort_image(image, thread_id)

  image_summary("final_image", image)

  # Rescale to [-1,1] instead of [0, 1]
  image = tf.sub(image, 0.5)
  image = tf.mul(image, 2.0)
  return image
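
For a sense of the end-to-end effect, a minimal sketch that runs process_image on a single JPEG in eval mode (decode, bilinear resize to 346x346, central crop to 299x299, rescale to [-1, 1]). The file path is illustrative, and the session setup assumes the TensorFlow 0.x API used throughout this commit:

import tensorflow as tf

from im2txt.ops import image_processing

# Illustrative path; any JPEG works.
encoded = tf.gfile.FastGFile("/tmp/example.jpg", "rb").read()
image = image_processing.process_image(tf.constant(encoded),
                                       is_training=False,
                                       height=299,
                                       width=299)
with tf.Session() as sess:
  result = sess.run(image)  # float32 array of shape (299, 299, 3) in [-1, 1]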

im2txt/im2txt/ops/inputs.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def parse_sequence_example(serialized, image_feature, caption_feature):
  """Parses a tensorflow.SequenceExample into an image and caption.

  Args:
    serialized: A scalar string Tensor; a single serialized SequenceExample.
    image_feature: Name of SequenceExample context feature containing image
      data.
    caption_feature: Name of SequenceExample feature list containing integer
      captions.

  Returns:
    encoded_image: A scalar string Tensor containing a JPEG encoded image.
    caption: A 1-D int64 Tensor with dynamically specified length.
  """
  context, sequence = tf.parse_single_sequence_example(
      serialized,
      context_features={
          image_feature: tf.FixedLenFeature([], dtype=tf.string)
      },
      sequence_features={
          caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
      })

  encoded_image = context[image_feature]
  caption = sequence[caption_feature]
  return encoded_image, caption


def prefetch_input_data(reader,
                        file_pattern,
                        is_training,
                        batch_size,
                        values_per_shard,
                        input_queue_capacity_factor=16,
                        num_reader_threads=1,
                        shard_queue_name="filename_queue",
                        value_queue_name="input_queue"):
  """Prefetches string values from disk into an input queue.

  In training the capacity of the queue is important because a larger queue
  means better mixing of training examples between shards. The minimum number of
  values kept in the queue is values_per_shard * input_queue_capacity_factor,
  where input_queue_capacity_factor should be chosen to trade off better mixing
  with memory usage.

  Args:
    reader: Instance of tf.ReaderBase.
    file_pattern: Comma-separated list of file patterns (e.g.
        /tmp/train_data-?????-of-00100).
    is_training: Boolean; whether prefetching for training or eval.
    batch_size: Model batch size used to determine queue capacity.
    values_per_shard: Approximate number of values per shard.
    input_queue_capacity_factor: Minimum number of values to keep in the queue
      in multiples of values_per_shard. See comments above.
    num_reader_threads: Number of reader threads to fill the queue.
    shard_queue_name: Name for the shards filename queue.
    value_queue_name: Name for the values input queue.

  Returns:
    A Queue containing prefetched string values.
  """
  data_files = []
  for pattern in file_pattern.split(","):
    data_files.extend(tf.gfile.Glob(pattern))
  if not data_files:
    tf.logging.fatal("Found no input files matching %s", file_pattern)
  else:
    tf.logging.info("Prefetching values from %d files matching %s",
                    len(data_files), file_pattern)

  if is_training:
    filename_queue = tf.train.string_input_producer(
        data_files, shuffle=True, capacity=16, name=shard_queue_name)
    min_queue_examples = values_per_shard * input_queue_capacity_factor
    capacity = min_queue_examples + 100 * batch_size
    values_queue = tf.RandomShuffleQueue(
        capacity=capacity,
        min_after_dequeue=min_queue_examples,
        dtypes=[tf.string],
        name="random_" + value_queue_name)
  else:
    filename_queue = tf.train.string_input_producer(
        data_files, shuffle=False, capacity=1, name=shard_queue_name)
    capacity = values_per_shard + 3 * batch_size
    values_queue = tf.FIFOQueue(
        capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)

  enqueue_ops = []
  for _ in range(num_reader_threads):
    _, value = reader.read(filename_queue)
    enqueue_ops.append(values_queue.enqueue([value]))
  tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
      values_queue, enqueue_ops))
  tf.scalar_summary(
      "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
      tf.cast(values_queue.size(), tf.float32) * (1. / capacity))

  return values_queue


def batch_with_dynamic_pad(images_and_captions,
                           batch_size,
                           queue_capacity,
                           add_summaries=True):
  """Batches input images and captions.

  This function splits the caption into an input sequence and a target sequence,
  where the target sequence is the input sequence right-shifted by 1. Input and
  target sequences are batched and padded up to the maximum length of sequences
  in the batch. A mask is created to distinguish real words from padding words.

  Example:
    Actual captions in the batch ('-' denotes padded character):
      [
        [ 1 2 3 4 5 ],
        [ 1 2 3 4 - ],
        [ 1 2 3 - - ],
      ]

    input_seqs:
      [
        [ 1 2 3 4 ],
        [ 1 2 3 - ],
        [ 1 2 - - ],
      ]

    target_seqs:
      [
        [ 2 3 4 5 ],
        [ 2 3 4 - ],
        [ 2 3 - - ],
      ]

    mask:
      [
        [ 1 1 1 1 ],
        [ 1 1 1 0 ],
        [ 1 1 0 0 ],
      ]

  Args:
    images_and_captions: A list of pairs [image, caption], where image is a
      Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
      any length. Each pair will be processed and added to the queue in a
      separate thread.
    batch_size: Batch size.
    queue_capacity: Queue capacity.
    add_summaries: If true, add caption length summaries.

  Returns:
    images: A Tensor of shape [batch_size, height, width, channels].
    input_seqs: An int32 Tensor of shape [batch_size, padded_length].
    target_seqs: An int32 Tensor of shape [batch_size, padded_length].
    mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
  """
  enqueue_list = []
  for image, caption in images_and_captions:
    caption_length = tf.shape(caption)[0]
    input_length = tf.expand_dims(tf.sub(caption_length, 1), 0)

    input_seq = tf.slice(caption, [0], input_length)
    target_seq = tf.slice(caption, [1], input_length)
    indicator = tf.ones(input_length, dtype=tf.int32)
    enqueue_list.append([image, input_seq, target_seq, indicator])

  images, input_seqs, target_seqs, mask = tf.train.batch_join(
      enqueue_list,
      batch_size=batch_size,
      capacity=queue_capacity,
      dynamic_pad=True,
      name="batch_and_pad")

  if add_summaries:
    lengths = tf.add(tf.reduce_sum(mask, 1), 1)
    tf.scalar_summary("caption_length/batch_min", tf.reduce_min(lengths))
    tf.scalar_summary("caption_length/batch_max", tf.reduce_max(lengths))
    tf.scalar_summary("caption_length/batch_mean", tf.reduce_mean(lengths))

  return images, input_seqs, target_seqs, mask
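
The split documented in batch_with_dynamic_pad is just two offset slices of the same caption. A toy restatement on a single unpadded caption, with NumPy slicing standing in for tf.slice:

import numpy as np

caption = np.array([1, 2, 3, 4, 5])
input_seq = caption[:-1]        # [1 2 3 4]
target_seq = caption[1:]        # [2 3 4 5]  (input right-shifted by 1)
mask = np.ones_like(input_seq)  # [1 1 1 1]; zeros appear only where dynamic
                                # padding extends shorter captions in a batch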

im2txt/im2txt/run_inference.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Generate captions for images using default beam search parameters."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

import tensorflow as tf

from im2txt import configuration
from im2txt import inference_wrapper
from im2txt.inference_utils import caption_generator
from im2txt.inference_utils import vocabulary

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("checkpoint_path", "",
                       "Model checkpoint file or directory containing a "
                       "model checkpoint file.")
tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.")
tf.flags.DEFINE_string("input_files", "",
                       "File pattern or comma-separated list of file patterns "
                       "of image files.")


def main(_):
  # Build the inference graph.
  g = tf.Graph()
  with g.as_default():
    model = inference_wrapper.InferenceWrapper()
    restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                               FLAGS.checkpoint_path)
  g.finalize()

  # Create the vocabulary.
  vocab = vocabulary.Vocabulary(FLAGS.vocab_file)

  filenames = []
  for file_pattern in FLAGS.input_files.split(","):
    filenames.extend(tf.gfile.Glob(file_pattern))
  tf.logging.info("Running caption generation on %d files matching %s",
                  len(filenames), FLAGS.input_files)

  with tf.Session(graph=g) as sess:
    # Load the model from checkpoint.
    restore_fn(sess)

    # Prepare the caption generator. Here we are implicitly using the default
    # beam search parameters. See caption_generator.py for a description of the
    # available beam search parameters.
    generator = caption_generator.CaptionGenerator(model, vocab)

    for filename in filenames:
      with tf.gfile.GFile(filename, "r") as f:
        image = f.read()
      captions = generator.beam_search(sess, image)
      print("Captions for image %s:" % os.path.basename(filename))
      for i, caption in enumerate(captions):
        # Ignore begin and end words.
        sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
        sentence = " ".join(sentence)
        print("  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))


if __name__ == "__main__":
  tf.app.run()
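
The inner loop trims the begin and end words by slicing caption.sentence[1:-1] and converts the beam's log probability with math.exp. A toy restatement with made-up word ids and a dict standing in for the Vocabulary class:

import math

caption_sentence = [0, 7, 12, 4, 1]       # hypothetical: 0 = <S>, 1 = </S>
fake_vocab = {7: "a", 12: "brown", 4: "dog"}
words = [fake_vocab[w] for w in caption_sentence[1:-1]]  # drop begin/end words
print(" ".join(words))                    # a brown dog
print("p=%f" % math.exp(-2.3))            # logprob -2.3 -> p=0.100259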

im2txt/im2txt/show_and_tell_model.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
"Show and Tell: A Neural Image Caption Generator"
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from im2txt.ops import image_embedding
from im2txt.ops import image_processing
from im2txt.ops import inputs as input_ops


class ShowAndTellModel(object):
  """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.

  "Show and Tell: A Neural Image Caption Generator"
  Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
  """

  def __init__(self, config, mode, train_inception=False):
    """Basic setup.

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "inference".
      train_inception: Whether the inception submodel variables are trainable.
    """
    assert mode in ["train", "eval", "inference"]
    self.config = config
    self.mode = mode
    self.train_inception = train_inception

    # Reader for the input data.
    self.reader = tf.TFRecordReader()

    # To match the "Show and Tell" paper we initialize all variables with a
    # random uniform initializer.
    self.initializer = tf.random_uniform_initializer(
        minval=-self.config.initializer_scale,
        maxval=self.config.initializer_scale)

    # A float32 Tensor with shape [batch_size, height, width, channels].
    self.images = None

    # An int32 Tensor with shape [batch_size, padded_length].
    self.input_seqs = None

    # An int32 Tensor with shape [batch_size, padded_length].
    self.target_seqs = None

    # An int32 0/1 Tensor with shape [batch_size, padded_length].
    self.input_mask = None

    # A float32 Tensor with shape [batch_size, embedding_size].
    self.image_embeddings = None

    # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
    self.seq_embeddings = None

    # A float32 scalar Tensor; the total loss for the trainer to optimize.
    self.total_loss = None

    # A float32 Tensor with shape [batch_size * padded_length].
    self.target_cross_entropy_losses = None

    # A float32 Tensor with shape [batch_size * padded_length].
    self.target_cross_entropy_loss_weights = None

    # Collection of variables from the inception submodel.
    self.inception_variables = []

    # Function to restore the inception submodel from checkpoint.
    self.init_fn = None

    # Global step Tensor.
    self.global_step = None

  def is_training(self):
    """Returns true if the model is built for training mode."""
    return self.mode == "train"

  def process_image(self, encoded_image, thread_id=0):
    """Decodes and processes an image string.

    Args:
      encoded_image: A scalar string Tensor; the encoded image.
      thread_id: Preprocessing thread id used to select the ordering of color
        distortions.

    Returns:
      A float32 Tensor of shape [height, width, 3]; the processed image.
    """
    return image_processing.process_image(encoded_image,
                                          is_training=self.is_training(),
                                          height=self.config.image_height,
                                          width=self.config.image_width,
                                          thread_id=thread_id,
                                          image_format=self.config.image_format)

  def build_inputs(self):
    """Input prefetching, preprocessing and batching.

    Outputs:
      self.images
      self.input_seqs
      self.target_seqs (training and eval only)
      self.input_mask (training and eval only)
    """
    if self.mode == "inference":
      # In inference mode, images and inputs are fed via placeholders.
      image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
      input_feed = tf.placeholder(dtype=tf.int64,
                                  shape=[None],  # batch_size
                                  name="input_feed")

      # Process image and insert batch dimensions.
      images = tf.expand_dims(self.process_image(image_feed), 0)
      input_seqs = tf.expand_dims(input_feed, 1)

      # No target sequences or input mask in inference mode.
      target_seqs = None
      input_mask = None
    else:
      # Prefetch serialized SequenceExample protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          is_training=self.is_training(),
          batch_size=self.config.batch_size,
          values_per_shard=self.config.values_per_input_shard,
          input_queue_capacity_factor=self.config.input_queue_capacity_factor,
          num_reader_threads=self.config.num_input_reader_threads)

      # Image processing and random distortion. Split across multiple threads
      # with each thread applying a slightly different distortion.
      assert self.config.num_preprocess_threads % 2 == 0
      images_and_captions = []
      for thread_id in range(self.config.num_preprocess_threads):
        serialized_sequence_example = input_queue.dequeue()
        encoded_image, caption = input_ops.parse_sequence_example(
            serialized_sequence_example,
            image_feature=self.config.image_feature_name,
            caption_feature=self.config.caption_feature_name)
        image = self.process_image(encoded_image, thread_id=thread_id)
        images_and_captions.append([image, caption])

      # Batch inputs.
      queue_capacity = (2 * self.config.num_preprocess_threads *
                        self.config.batch_size)
      images, input_seqs, target_seqs, input_mask = (
          input_ops.batch_with_dynamic_pad(images_and_captions,
                                           batch_size=self.config.batch_size,
                                           queue_capacity=queue_capacity))

    self.images = images
    self.input_seqs = input_seqs
    self.target_seqs = target_seqs
    self.input_mask = input_mask

  def build_image_embeddings(self):
    """Builds the image model subgraph and generates image embeddings.

    Inputs:
      self.images

    Outputs:
      self.image_embeddings
    """
    inception_output = image_embedding.inception_v3(
        self.images,
        trainable=self.train_inception,
        is_training=self.is_training())
    self.inception_variables = tf.get_collection(
        tf.GraphKeys.VARIABLES, scope="InceptionV3")

    # Map inception output into embedding space.
    with tf.variable_scope("image_embedding") as scope:
      image_embeddings = tf.contrib.layers.fully_connected(
          inputs=inception_output,
          num_outputs=self.config.embedding_size,
          activation_fn=None,
          weights_initializer=self.initializer,
          biases_initializer=None,
          scope=scope)

    # Save the embedding size in the graph.
    tf.constant(self.config.embedding_size, name="embedding_size")

    self.image_embeddings = image_embeddings

  def build_seq_embeddings(self):
    """Builds the input sequence embeddings.

    Inputs:
      self.input_seqs

    Outputs:
      self.seq_embeddings
    """
    with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
      embedding_map = tf.get_variable(
          name="map",
          shape=[self.config.vocab_size, self.config.embedding_size],
          initializer=self.initializer)
      seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)

    self.seq_embeddings = seq_embeddings

  def build_model(self):
    """Builds the model.

    Inputs:
      self.image_embeddings
      self.seq_embeddings
      self.target_seqs (training and eval only)
      self.input_mask (training and eval only)

    Outputs:
      self.total_loss (training and eval only)
      self.target_cross_entropy_losses (training and eval only)
      self.target_cross_entropy_loss_weights (training and eval only)
    """
    # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
    # modified LSTM in the "Show and Tell" paper has no biases and outputs
    # new_c * sigmoid(o).
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
        num_units=self.config.num_lstm_units, state_is_tuple=True)
    if self.mode == "train":
      lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
          lstm_cell,
          input_keep_prob=self.config.lstm_dropout_keep_prob,
          output_keep_prob=self.config.lstm_dropout_keep_prob)

    with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
      # Feed the image embeddings to set the initial LSTM state.
      zero_state = lstm_cell.zero_state(
          batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
      _, initial_state = lstm_cell(self.image_embeddings, zero_state)

      # Allow the LSTM variables to be reused.
      lstm_scope.reuse_variables()

      if self.mode == "inference":
        # In inference mode, use concatenated states for convenient feeding and
        # fetching.
        tf.concat(1, initial_state, name="initial_state")

        # Placeholder for feeding a batch of concatenated states.
        state_feed = tf.placeholder(dtype=tf.float32,
                                    shape=[None, sum(lstm_cell.state_size)],
                                    name="state_feed")
        state_tuple = tf.split(1, 2, state_feed)

        # Run a single LSTM step.
        lstm_outputs, state_tuple = lstm_cell(
            inputs=tf.squeeze(self.seq_embeddings, squeeze_dims=[1]),
            state=state_tuple)

        # Concatenate the resulting state.
        tf.concat(1, state_tuple, name="state")
      else:
        # Run the batch of sequence embeddings through the LSTM.
        sequence_length = tf.reduce_sum(self.input_mask, 1)
        lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
                                            inputs=self.seq_embeddings,
                                            sequence_length=sequence_length,
                                            initial_state=initial_state,
                                            dtype=tf.float32,
                                            scope=lstm_scope)

    # Stack batches vertically.
    lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

    with tf.variable_scope("logits") as logits_scope:
      logits = tf.contrib.layers.fully_connected(
          inputs=lstm_outputs,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
          weights_initializer=self.initializer,
          scope=logits_scope)

    if self.mode == "inference":
      tf.nn.softmax(logits, name="softmax")
    else:
      targets = tf.reshape(self.target_seqs, [-1])
      weights = tf.to_float(tf.reshape(self.input_mask, [-1]))

      # Compute losses.
      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
      batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
                          tf.reduce_sum(weights),
                          name="batch_loss")
      tf.contrib.losses.add_loss(batch_loss)
      total_loss = tf.contrib.losses.get_total_loss()

      # Add summaries.
      tf.scalar_summary("batch_loss", batch_loss)
      tf.scalar_summary("total_loss", total_loss)
      for var in tf.trainable_variables():
        tf.histogram_summary(var.op.name, var)

      self.total_loss = total_loss
      self.target_cross_entropy_losses = losses  # Used in evaluation.
      self.target_cross_entropy_loss_weights = weights  # Used in evaluation.

  def setup_inception_initializer(self):
    """Sets up the function to restore inception variables from checkpoint."""
    if self.mode != "inference":
      # Restore inception variables only.
      saver = tf.train.Saver(self.inception_variables)

      def restore_fn(sess):
        tf.logging.info("Restoring Inception variables from checkpoint file %s",
                        self.config.inception_checkpoint_file)
        saver.restore(sess, self.config.inception_checkpoint_file)

      self.init_fn = restore_fn

  def setup_global_step(self):
    """Sets up the global step Tensor."""
    global_step = tf.Variable(
        initial_value=0,
        name="global_step",
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.VARIABLES])

    self.global_step = global_step

  def setup_saver(self):
    """Sets up the Saver for loading and saving model checkpoints."""
    self.saver = tf.train.Saver(
        max_to_keep=self.config.max_checkpoints_to_keep,
        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)

  def build(self):
    """Creates all ops for training and evaluation."""
    self.build_inputs()
    self.build_image_embeddings()
    self.build_seq_embeddings()
    self.build_model()
    self.setup_inception_initializer()
    self.setup_global_step()
    self.setup_saver()
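
The named tensors created in inference mode ("image_feed:0", "lstm/initial_state:0", "input_feed:0", "lstm/state_feed:0", "lstm/state:0", "softmax:0") define the step-by-step feeding protocol used at caption time. A minimal greedy-decoding sketch against those names; the start/end word ids and the image path are assumptions, argmax stands in for the beam search used by run_inference.py, and a real run would restore a trained checkpoint instead of initializing variables:

import numpy as np
import tensorflow as tf

from im2txt import configuration, show_and_tell_model

g = tf.Graph()
with g.as_default():
  model = show_and_tell_model.ShowAndTellModel(
      configuration.ModelConfig(), mode="inference")
  model.build()

with tf.Session(graph=g) as sess:
  sess.run(tf.initialize_all_variables())  # placeholder for a checkpoint restore
  encoded_image = open("/tmp/example.jpg", "rb").read()  # illustrative path
  # Feed the image once to obtain the initial LSTM state.
  state = sess.run("lstm/initial_state:0",
                   feed_dict={"image_feed:0": encoded_image})
  word, sentence = 0, []  # assumption: id 0 is the start word
  for _ in range(20):
    # One LSTM step: feed the previous word and state, fetch the next
    # word distribution and the updated state.
    softmax, state = sess.run(["softmax:0", "lstm/state:0"],
                              feed_dict={"input_feed:0": [word],
                                         "lstm/state_feed:0": state})
    word = int(np.argmax(softmax[0]))
    if word == 1:  # assumption: id 1 is the end word
      break
    sentence.append(word)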

im2txt/im2txt/show_and_tell_model_test.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.im2txt.show_and_tell_model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from im2txt import configuration
from im2txt import show_and_tell_model


class ShowAndTellModel(show_and_tell_model.ShowAndTellModel):
  """Subclass of ShowAndTellModel without the disk I/O."""

  def build_inputs(self):
    if self.mode == "inference":
      # Inference mode doesn't read from disk, so defer to parent.
      return super(ShowAndTellModel, self).build_inputs()
    else:
      # Replace disk I/O with random Tensors.
      self.images = tf.random_uniform(
          shape=[self.config.batch_size, self.config.image_height,
                 self.config.image_width, 3],
          minval=-1,
          maxval=1)
      self.input_seqs = tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)
      self.target_seqs = tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)
      self.input_mask = tf.ones_like(self.input_seqs)


class ShowAndTellModelTest(tf.test.TestCase):

  def setUp(self):
    super(ShowAndTellModelTest, self).setUp()
    self._model_config = configuration.ModelConfig()

  def _countModelParameters(self):
    """Counts the number of parameters in the model at top level scope."""
    counter = {}
    for v in tf.all_variables():
      name = v.op.name.split("/")[0]
      num_params = v.get_shape().num_elements()
      assert num_params
      counter[name] = counter.get(name, 0) + num_params
    return counter

  def _checkModelParameters(self):
    """Verifies the number of parameters in the model."""
    param_counts = self._countModelParameters()
    expected_param_counts = {
        "InceptionV3": 21802784,
        # inception_output_size * embedding_size
        "image_embedding": 1048576,
        # vocab_size * embedding_size
        "seq_embedding": 6144000,
        # (embedding_size + num_lstm_units + 1) * 4 * num_lstm_units
        "lstm": 2099200,
        # (num_lstm_units + 1) * vocab_size
        "logits": 6156000,
        "global_step": 1,
    }
    self.assertDictEqual(expected_param_counts, param_counts)

  def _checkOutputs(self, expected_shapes, feed_dict=None):
    """Verifies that the model produces expected outputs.

    Args:
      expected_shapes: A dict mapping Tensor or Tensor name to expected output
        shape.
      feed_dict: Values of Tensors to feed into Session.run().
    """
    fetches = expected_shapes.keys()

    with self.test_session() as sess:
      sess.run(tf.initialize_all_variables())
      outputs = sess.run(fetches, feed_dict)

    for index, output in enumerate(outputs):
      tensor = fetches[index]
      expected = expected_shapes[tensor]
      actual = output.shape
      if expected != actual:
        self.fail("Tensor %s has shape %s (expected %s)." %
                  (tensor, actual, expected))

  def testBuildForTraining(self):
    model = ShowAndTellModel(self._model_config, mode="train")
    model.build()

    self._checkModelParameters()

    expected_shapes = {
        # [batch_size, image_height, image_width, 3]
        model.images: (32, 299, 299, 3),
        # [batch_size, sequence_length]
        model.input_seqs: (32, 15),
        # [batch_size, sequence_length]
        model.target_seqs: (32, 15),
        # [batch_size, sequence_length]
        model.input_mask: (32, 15),
        # [batch_size, embedding_size]
        model.image_embeddings: (32, 512),
        # [batch_size, sequence_length, embedding_size]
        model.seq_embeddings: (32, 15, 512),
        # Scalar
        model.total_loss: (),
        # [batch_size * sequence_length]
        model.target_cross_entropy_losses: (480,),
        # [batch_size * sequence_length]
        model.target_cross_entropy_loss_weights: (480,),
    }
    self._checkOutputs(expected_shapes)

  def testBuildForEval(self):
    model = ShowAndTellModel(self._model_config, mode="eval")
    model.build()

    self._checkModelParameters()

    expected_shapes = {
        # [batch_size, image_height, image_width, 3]
        model.images: (32, 299, 299, 3),
        # [batch_size, sequence_length]
        model.input_seqs: (32, 15),
        # [batch_size, sequence_length]
        model.target_seqs: (32, 15),
        # [batch_size, sequence_length]
        model.input_mask: (32, 15),
        # [batch_size, embedding_size]
        model.image_embeddings: (32, 512),
        # [batch_size, sequence_length, embedding_size]
        model.seq_embeddings: (32, 15, 512),
        # Scalar
        model.total_loss: (),
        # [batch_size * sequence_length]
        model.target_cross_entropy_losses: (480,),
        # [batch_size * sequence_length]
        model.target_cross_entropy_loss_weights: (480,),
    }
    self._checkOutputs(expected_shapes)

  def testBuildForInference(self):
    model = ShowAndTellModel(self._model_config, mode="inference")
    model.build()

    self._checkModelParameters()

    # Test feeding an image to get the initial LSTM state.
    images_feed = np.random.rand(1, 299, 299, 3)
    feed_dict = {model.images: images_feed}
    expected_shapes = {
        # [batch_size, embedding_size]
        model.image_embeddings: (1, 512),
        # [batch_size, 2 * num_lstm_units]
        "lstm/initial_state:0": (1, 1024),
    }
    self._checkOutputs(expected_shapes, feed_dict)

    # Test feeding a batch of inputs and LSTM states to get softmax output and
    # LSTM states.
    input_feed = np.random.randint(0, 10, size=3)
    state_feed = np.random.rand(3, 1024)
    feed_dict = {"input_feed:0": input_feed, "lstm/state_feed:0": state_feed}
    expected_shapes = {
        # [batch_size, 2 * num_lstm_units]
        "lstm/state:0": (3, 1024),
        # [batch_size, vocab_size]
        "softmax:0": (3, 12000),
    }
    self._checkOutputs(expected_shapes, feed_dict)


if __name__ == "__main__":
  tf.test.main()
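
The expected_param_counts entries follow directly from the formulas in the inline comments, using the configuration values these tests imply (embedding_size = 512, num_lstm_units = 512, vocab_size = 12000, and the 2048-dimensional InceptionV3 output seen in image_embedding_test.py). A quick arithmetic check:

embedding_size = 512
num_lstm_units = 512
vocab_size = 12000
inception_output_size = 2048

print(inception_output_size * embedding_size)                      # 1048576 (image_embedding)
print(vocab_size * embedding_size)                                 # 6144000 (seq_embedding)
print((embedding_size + num_lstm_units + 1) * 4 * num_lstm_units)  # 2099200 (lstm)
print((num_lstm_units + 1) * vocab_size)                           # 6156000 (logits)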

im2txt/im2txt/train.py (new file, mode 100644)

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from im2txt import configuration
from im2txt import show_and_tell_model

FLAGS = tf.app.flags.FLAGS

tf.flags.DEFINE_string("input_file_pattern", "",
                       "File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("inception_checkpoint_file", "",
                       "Path to a pretrained inception_v3 model.")
tf.flags.DEFINE_string("train_dir", "",
                       "Directory for saving and loading model checkpoints.")
tf.flags.DEFINE_boolean("train_inception", False,
                        "Whether to train inception submodel variables.")
tf.flags.DEFINE_integer("number_of_steps", 1000000, "Number of training steps.")
tf.flags.DEFINE_integer("log_every_n_steps", 1,
                        "Frequency at which loss and global step are logged.")

tf.logging.set_verbosity(tf.logging.INFO)


def main(unused_argv):
  assert FLAGS.input_file_pattern, "--input_file_pattern is required"
  assert FLAGS.train_dir, "--train_dir is required"

  model_config = configuration.ModelConfig()
  model_config.input_file_pattern = FLAGS.input_file_pattern
  model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
  training_config = configuration.TrainingConfig()

  # Create training directory.
  train_dir = FLAGS.train_dir
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info("Creating training directory: %s", train_dir)
    tf.gfile.MakeDirs(train_dir)

  # Build the TensorFlow graph.
  g = tf.Graph()
  with g.as_default():
    # Build the model.
    model = show_and_tell_model.ShowAndTellModel(
        model_config, mode="train", train_inception=FLAGS.train_inception)
    model.build()

    # Set up the learning rate.
    learning_rate_decay_fn = None
    if FLAGS.train_inception:
      learning_rate = tf.constant(training_config.train_inception_learning_rate)
    else:
      learning_rate = tf.constant(training_config.initial_learning_rate)
      if training_config.learning_rate_decay_factor > 0:
        num_batches_per_epoch = (training_config.num_examples_per_epoch /
                                 model_config.batch_size)
        decay_steps = int(num_batches_per_epoch *
                          training_config.num_epochs_per_decay)

        def _learning_rate_decay_fn(learning_rate, global_step):
          return tf.train.exponential_decay(
              learning_rate,
              global_step,
              decay_steps=decay_steps,
              decay_rate=training_config.learning_rate_decay_factor,
              staircase=True)

        learning_rate_decay_fn = _learning_rate_decay_fn

    # Set up the training ops.
    train_op = tf.contrib.layers.optimize_loss(
        loss=model.total_loss,
        global_step=model.global_step,
        learning_rate=learning_rate,
        optimizer=training_config.optimizer,
        clip_gradients=training_config.clip_gradients,
        learning_rate_decay_fn=learning_rate_decay_fn)

  # Run training.
  tf.contrib.slim.learning.train(
      train_op,
      train_dir,
      log_every_n_steps=FLAGS.log_every_n_steps,
      graph=g,
      global_step=model.global_step,
      number_of_steps=FLAGS.number_of_steps,
      init_fn=model.init_fn,
      saver=model.saver)


if __name__ == "__main__":
  tf.app.run()
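
_learning_rate_decay_fn wraps tf.train.exponential_decay with staircase=True, so the learning rate drops in discrete steps every decay_steps batches rather than continuously. A small sketch of that schedule with illustrative numbers (not the values in configuration.py):

initial_learning_rate = 2.0
learning_rate_decay_factor = 0.5
decay_steps = 100

def decayed(step):
  # staircase=True floors the exponent to an integer, producing a step function.
  return initial_learning_rate * learning_rate_decay_factor ** (step // decay_steps)

print([decayed(s) for s in (0, 99, 100, 250)])  # [2.0, 2.0, 1.0, 0.5]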