Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f02e6013
Commit
f02e6013
authored
Jan 04, 2018
by
Alexander Gorban
Browse files
Merge remote-tracking branch 'tensorflow/master'
parents
f5f1e12a
b719165d
Changes
157
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
720 additions
and
0 deletions
+720
-0
research/gan/cifar/util_test.py
research/gan/cifar/util_test.py
+62
-0
research/gan/g3doc/cifar_conditional_gan.png
research/gan/g3doc/cifar_conditional_gan.png
+0
-0
research/gan/g3doc/cifar_unconditional_gan.png
research/gan/g3doc/cifar_unconditional_gan.png
+0
-0
research/gan/g3doc/compression_wf0.png
research/gan/g3doc/compression_wf0.png
+0
-0
research/gan/g3doc/compression_wf10000.png
research/gan/g3doc/compression_wf10000.png
+0
-0
research/gan/g3doc/mnist_conditional_gan.png
research/gan/g3doc/mnist_conditional_gan.png
+0
-0
research/gan/g3doc/mnist_estimator_unconditional_gan.png
research/gan/g3doc/mnist_estimator_unconditional_gan.png
+0
-0
research/gan/g3doc/mnist_infogan.png
research/gan/g3doc/mnist_infogan.png
+0
-0
research/gan/g3doc/mnist_unconditional_gan.png
research/gan/g3doc/mnist_unconditional_gan.png
+0
-0
research/gan/image_compression/data_provider.py
research/gan/image_compression/data_provider.py
+93
-0
research/gan/image_compression/data_provider_test.py
research/gan/image_compression/data_provider_test.py
+60
-0
research/gan/image_compression/eval.py
research/gan/image_compression/eval.py
+101
-0
research/gan/image_compression/eval_test.py
research/gan/image_compression/eval_test.py
+32
-0
research/gan/image_compression/launch_jobs.sh
research/gan/image_compression/launch_jobs.sh
+84
-0
research/gan/image_compression/networks.py
research/gan/image_compression/networks.py
+145
-0
research/gan/image_compression/networks_test.py
research/gan/image_compression/networks_test.py
+98
-0
research/gan/image_compression/summaries.py
research/gan/image_compression/summaries.py
+44
-0
research/gan/image_compression/testdata/labels.txt
research/gan/image_compression/testdata/labels.txt
+1
-0
research/gan/image_compression/testdata/train-00000-of-00128
research/gan/image_compression/testdata/train-00000-of-00128
+0
-0
research/gan/image_compression/testdata/validation-00000-of-00128
.../gan/image_compression/testdata/validation-00000-of-00128
+0
-0
No files found.
research/gan/cifar/util_test.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.util."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
util
mock
=
tf
.
test
.
mock
class UtilTest(tf.test.TestCase):
  """Graph-construction smoke tests for the helpers in `util`."""

  def test_get_generator_conditioning(self):
    """The conditioning tensor must have static shape [12, 4]."""
    expected_shape = [12, 4]
    conditioning = util.get_generator_conditioning(*expected_shape)
    self.assertEqual(expected_shape, conditioning.shape.as_list())

  def test_get_image_grid(self):
    """Building an image grid from a batch of zeros must not raise."""
    blank_images = tf.zeros([6, 28, 28, 1])
    util.get_image_grid(
        blank_images, batch_size=6, num_classes=3, num_images_per_class=1)

  # `inception_score` is expensive, so it is replaced by a mock.
  @mock.patch.object(util.tfgan.eval, 'inception_score', autospec=True)
  def test_get_inception_scores(self, mock_inception_score):
    """`get_inception_scores` must build with the scorer mocked out."""
    mock_inception_score.return_value = 1.0
    images = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
    util.get_inception_scores(
        images, batch_size=100, num_inception_images=10)

  # `frechet_inception_distance` is expensive, so it is replaced by a mock.
  @mock.patch.object(
      util.tfgan.eval, 'frechet_inception_distance', autospec=True)
  def test_get_frechet_inception_distance(self, mock_fid):
    """`get_frechet_inception_distance` must build with the metric mocked."""
    mock_fid.return_value = 1.0
    first_images = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
    second_images = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
    util.get_frechet_inception_distance(
        first_images, second_images,
        batch_size=100, num_inception_images=10)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/g3doc/cifar_conditional_gan.png
0 → 100644
View file @
f02e6013
395 KB
research/gan/g3doc/cifar_unconditional_gan.png
0 → 100644
View file @
f02e6013
830 KB
research/gan/g3doc/compression_wf0.png
0 → 100644
View file @
f02e6013
80 KB
research/gan/g3doc/compression_wf10000.png
0 → 100644
View file @
f02e6013
84.2 KB
research/gan/g3doc/mnist_conditional_gan.png
0 → 100644
View file @
f02e6013
97.3 KB
research/gan/g3doc/mnist_estimator_unconditional_gan.png
0 → 100644
View file @
f02e6013
128 KB
research/gan/g3doc/mnist_infogan.png
0 → 100644
View file @
f02e6013
53.5 KB
research/gan/g3doc/mnist_unconditional_gan.png
0 → 100644
View file @
f02e6013
89.5 KB
research/gan/image_compression/data_provider.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the compression image data."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
slim.datasets
import
dataset_factory
as
datasets
slim
=
tf
.
contrib
.
slim
def provide_data(split_name, batch_size, dataset_dir,
                 dataset_name='imagenet', num_readers=1, num_threads=1,
                 patch_size=128):
  """Builds a queue-backed batch of normalized image patches.

  Args:
    split_name: Either 'train' or 'validation'. Only 'train' enables
      shuffling, both in the reader and in the batching queue.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the data can be found. If `None`, use
      default.
    dataset_name: Name of the dataset.
    num_readers: Number of dataset readers.
    num_threads: Number of batching threads used for the shuffled ('train')
      queue. Every other split always uses a single thread.
    patch_size: Side length of the square patch taken from each image.

  Returns:
    images: A `Tensor` of size [batch_size, patch_size, patch_size, channels]
  """
  is_train_split = (split_name == 'train')

  source = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  reader = slim.dataset_data_provider.DatasetDataProvider(
      source,
      num_readers=num_readers,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=is_train_split)
  [raw_image] = reader.get(['image'])

  # Crop or pad every image to a fixed-size, 3-channel patch.
  fixed_patch = tf.image.resize_image_with_crop_or_pad(
      raw_image, patch_size, patch_size)
  fixed_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])

  # Scale into a range strictly inside [-1, 1], so that network outputs
  # aren't forced to the extreme ranges.
  # NOTE(review): this presumes pixel values in [0, 255] -- confirm.
  normalized_patch = (tf.to_float(fixed_patch) - 128.0) / 142.0

  if not is_train_split:
    # A single thread keeps the batch order deterministic.
    return tf.train.batch(
        [normalized_patch],
        batch_size=batch_size,
        num_threads=1,
        capacity=5 * batch_size)

  return tf.train.shuffle_batch(
      [normalized_patch],
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=5 * batch_size,
      min_after_dequeue=batch_size)
def float_image_to_uint8(image):
  """Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.

  Applies `image * 142 + 128`, the inverse of the `(x - 128) / 142`
  normalization used by `provide_data`, then casts to uint8.

  Args:
    image: An image tensor. Values should be in [-0.9, 0.9). Values outside
      that interval (for example softsign outputs, which span [-1, 1]) are
      saturated rather than overflowing.

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  rescaled = (image * 142.0) + 128.0
  # Clamp before casting: an input of 1.0 maps to 270 and -1.0 maps to -14,
  # neither of which is representable in uint8.
  clamped = tf.clip_by_value(rescaled, 0.0, 255.0)
  return tf.cast(clamped, tf.uint8)
research/gan/image_compression/data_provider_test.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
import
data_provider
class DataProviderTest(tf.test.TestCase):
  """Checks shape and value range of batches from `data_provider`."""

  def _test_data_provider_helper(self, split_name):
    """Builds a batch for `split_name`, runs it once and validates it."""
    batch_size = 3
    patch_size = 8
    expected_shape = [batch_size, patch_size, patch_size, 3]

    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=patch_size)
    # The static shape must be fully known at graph-construction time.
    self.assertListEqual(expected_shape, images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        self.assertEqual(tuple(expected_shape), images_out.shape)
        # Check range.
        within_unit_range = np.all(np.abs(images_out) <= 1.0)
        self.assertTrue(within_unit_range)

  def test_data_provider_train(self):
    """Exercises the shuffled 'train' split."""
    self._test_data_provider_helper('train')

  def test_data_provider_validation(self):
    """Exercises the 'validation' split."""
    self._test_data_provider_helper('validation')
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/eval.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained compression model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
data_provider
import
networks
import
summaries
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_string
(
'checkpoint_dir'
,
'/tmp/compression/'
,
'Directory where the model was written to.'
)
flags
.
DEFINE_string
(
'eval_dir'
,
'/tmp/compression/'
,
'Directory where the results are saved to.'
)
flags
.
DEFINE_integer
(
'max_number_of_evaluations'
,
None
,
'Number of times to run evaluation. If `None`, run '
'forever.'
)
flags
.
DEFINE_string
(
'dataset_dir'
,
None
,
'Location of data.'
)
# Compression-specific flags.
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'patch_size'
,
32
,
'The size of the patches to train on.'
)
flags
.
DEFINE_integer
(
'bits_per_patch'
,
1230
,
'The number of bits to produce per patch.'
)
flags
.
DEFINE_integer
(
'model_depth'
,
64
,
'Number of filters for compression model'
)
def main(_, run_eval_loop=True):
  """Builds the evaluation graph and, optionally, runs the eval loop.

  Args:
    _: Unused positional argument (supplied by `tf.app.run`).
    run_eval_loop: A python bool. If False, only the graph is constructed and
      no evaluation is run; intended for unit testing.
  """
  with tf.name_scope('inputs'):
    images = data_provider.provide_data(
        'validation',
        FLAGS.batch_size,
        dataset_dir=FLAGS.dataset_dir,
        patch_size=FLAGS.patch_size)

  # The variable scope must match the one used by the train job, otherwise
  # the checkpointed variables cannot be loaded.
  with tf.variable_scope('generator'):
    reconstructions, _, prebinary = networks.compression_model(
        images,
        num_bits=FLAGS.bits_per_patch,
        depth=FLAGS.model_depth,
        is_training=False)
  summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

  # Visualize losses: mean absolute pixel error, per example and overall.
  absolute_error = tf.abs(images - reconstructions)
  pixel_loss_per_example = tf.reduce_mean(absolute_error, axis=[1, 2, 3])
  pixel_loss = tf.reduce_mean(pixel_loss_per_example)
  tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
  tf.summary.scalar('pixel_l1_loss', pixel_loss)

  # Create ops to write images to disk.
  uint8_images = data_provider.float_image_to_uint8(images)
  uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
  uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
  output_path = '%s/%s' % (FLAGS.eval_dir, 'compression.png')
  encoded_png = tf.image.encode_png(uint8_reshaped[0])
  image_write_ops = tf.write_file(output_path, encoded_png)

  # For unit testing, use `run_eval_loop=False`.
  if run_eval_loop:
    eval_hooks = [
        tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
        tf.contrib.training.StopAfterNEvalsHook(1),
    ]
    tf.contrib.training.evaluate_repeatedly(
        FLAGS.checkpoint_dir,
        master=FLAGS.master,
        hooks=eval_hooks,
        eval_ops=image_write_ops,
        max_number_of_evaluations=FLAGS.max_number_of_evaluations)
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
research/gan/image_compression/eval_test.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.eval."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
eval
# pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase):
  """Smoke test for the image compression eval script."""

  def test_build_graph(self):
    """The eval graph must build without entering the evaluation loop."""
    unused_argv = None
    eval.main(unused_argv, run_eval_loop=False)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/launch_jobs.sh
0 → 100755
View file @
f02e6013
#!/bin/bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains image compression model on patches from Imagenet.
# 3. Evaluates the models and writes sample images to disk.
#
# Usage:
# cd models/research/gan/image_compression
# ./launch_jobs.sh ${weight_factor} ${git_repo}
set -e

# Weight of the adversarial loss.
weight_factor=$1
if [[ "$weight_factor" == "" ]]; then
  echo "'weight_factor' must not be empty." >&2
  # A bare `exit` would return the status of the `echo` above (success).
  exit 1
fi

# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
  echo "'git_repo' must not be empty." >&2
  exit 1
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model

# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/compression-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data

export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/slim:$git_repo/research/slim/nets

# A helper function for printing pretty output.
Banner () {
  local text=$1
  local green='\033[0;32m'
  local nc='\033[0m'  # No color.
  echo -e "${green}${text}${nc}"
}

# Download the dataset.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" "${DATASET_DIR}"

# Run the compression model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training an image compression model for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/image_compression/train.py" \
  --train_log_dir="${MODEL_TRAIN_DIR}" \
  --dataset_dir="${DATASET_DIR}" \
  --max_number_of_steps="${NUM_STEPS}" \
  --weight_factor="${weight_factor}" \
  --alsologtostderr
Banner "Finished training image compression model ${NUM_STEPS} steps."

# Run evaluation.
MODEL_EVAL_DIR="${EVAL_DIR}/wt${weight_factor}"
Banner "Starting evaluation of image compression model..."
# Use the exact flag name defined by eval.py (`max_number_of_evaluations`)
# instead of relying on the flag parser accepting an abbreviated prefix.
python "${git_repo}/research/gan/image_compression/eval.py" \
  --checkpoint_dir="${MODEL_TRAIN_DIR}" \
  --eval_dir="${MODEL_EVAL_DIR}" \
  --dataset_dir="${DATASET_DIR}" \
  --max_number_of_evaluations=1
Banner "Finished evaluation. See ${MODEL_EVAL_DIR} for output images."
research/gan/image_compression/networks.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN compression example using TFGAN."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
slim.nets
import
dcgan
from
slim.nets
import
pix2pix
def
_last_conv_layer
(
end_points
):
""""Returns the last convolutional layer from an endpoints dictionary."""
conv_list
=
[
k
if
k
[:
4
]
==
'conv'
else
None
for
k
in
end_points
.
keys
()]
conv_list
.
sort
()
return
end_points
[
conv_list
[
-
1
]]
def _encoder(img_batch, is_training=True, bits=64, depth=64):
  """Maps images to internal representation.

  Args:
    img_batch: The batch of images to encode; forwarded unchanged to
      `dcgan.discriminator`.
    is_training: Forwarded to `dcgan.discriminator` as its `is_training`
      argument.
    bits: Number of bits per patch.
    depth: Forwarded to `dcgan.discriminator` as its `depth` argument.

  Returns:
    Real-valued 2D Tensor of size [batch_size, bits].
  """
  _, end_points = dcgan.discriminator(
      img_batch, depth=depth, is_training=is_training, scope='Encoder')

  # (joelshor): Make the DCGAN convolutional layer that converts to logits
  # not trainable, since it doesn't affect the encoder output.

  # The encoder features are the pre-logit layer, i.e. the last conv layer.
  features = _last_conv_layer(end_points)

  # A 1x1 convolution without normalization or activation maps the features
  # to the requested number of bits.
  with tf.variable_scope('EncoderTransformer'):
    projected = tf.contrib.layers.conv2d(
        features,
        bits,
        kernel_size=1,
        stride=1,
        padding='VALID',
        normalizer_fn=None,
        activation_fn=None)

  # Drop dimensions 1 and 2 so that a rank-2 tensor remains.
  codes = tf.squeeze(projected, [1, 2])
  codes.shape.assert_has_rank(2)

  # Map encoded to the range [-1, 1].
  return tf.nn.softsign(codes)
def _binarizer(prebinary_codes, is_training):
  """Binarize compression logits.

  In training mode uniform noise in [-1, 1) is added to the codes, as in
  https://arxiv.org/pdf/1611.01704.pdf. In eval mode the codes are replaced
  by their sign, mapping [-1, 1] -> {-1, 1}.

  Args:
    prebinary_codes: Floating-point tensors corresponding to pre-binary codes.
      Shape is [batch, code_length].
    is_training: A python bool. If True, add noise. If false, binarize.

  Returns:
    Binarized codes. Shape is [batch, code_length].

  Raises:
    ValueError: If the shape of `prebinary_codes` isn't static.
  """
  if not is_training:
    return tf.sign(prebinary_codes)

  # Additive noise lets the model learn codes that survive binarization at
  # eval time (https://arxiv.org/pdf/1611.01704.pdf). Another option is to
  # use a stochastic node, as in https://arxiv.org/abs/1608.05148.
  uniform_noise = tf.random_uniform(
      prebinary_codes.shape, minval=-1.0, maxval=1.0)
  return prebinary_codes + uniform_noise
def _decoder(codes, final_size, is_training, depth=64):
  """Compression decoder.

  Args:
    codes: The codes to decode; forwarded unchanged to `dcgan.generator`.
    final_size: Forwarded to `dcgan.generator` as its `final_size` argument.
    is_training: Forwarded to `dcgan.generator` as its `is_training` argument.
    depth: Forwarded to `dcgan.generator` as its `depth` argument.

  Returns:
    The first output of `dcgan.generator` (built with `num_outputs=3`) passed
    through a softsign, so values lie in [-1, 1].
  """
  generator_output, _ = dcgan.generator(
      codes,
      depth=depth,
      final_size=final_size,
      num_outputs=3,
      is_training=is_training,
      scope='Decoder')

  # Map output to [-1, 1].
  # Use softsign instead of tanh, as per empirical results of
  # http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
  squashed = tf.nn.softsign(generator_output)
  return squashed
def _validate_image_inputs(image_batch):
  """Checks that `image_batch` is rank 4 with static non-batch dimensions.

  Args:
    image_batch: A tensor whose static shape is validated. Only the leading
      (batch) dimension may be unknown.
  """
  static_shape = image_batch.shape
  static_shape.assert_has_rank(4)
  non_batch_dims = static_shape[1:]
  non_batch_dims.assert_is_fully_defined()
def compression_model(image_batch, num_bits=64, depth=64, is_training=True):
  """Image compression model.

  Runs the encoder, the binarizer and the decoder in sequence.

  Args:
    image_batch: A batch of images to compress and reconstruct. Images should
      be normalized already. Shape is [batch, height, width, channels].
    num_bits: Desired number of bits per image in the compressed representation.
    depth: The base number of filters for the encoder and decoder networks.
    is_training: A python bool. If False, run in evaluation mode.

  Returns:
    uncompressed images, binary codes, prebinary codes
  """
  image_batch = tf.convert_to_tensor(image_batch)
  _validate_image_inputs(image_batch)

  # The decoder output size is taken from static dimension 1 of the input.
  output_size = image_batch.shape.as_list()[1]

  prebinary = _encoder(image_batch, is_training, num_bits, depth)
  binary = _binarizer(prebinary, is_training)
  reconstructed = _decoder(binary, output_size, is_training, depth)
  return reconstructed, binary, prebinary
def discriminator(image_batch, unused_conditioning=None, depth=64):
  """A thin wrapper around the pix2pix discriminator to conform to TFGAN API.

  Args:
    image_batch: The images to score; forwarded to
      `pix2pix.pix2pix_discriminator`.
    unused_conditioning: Ignored; present only to match the expected
      signature.
    depth: Base filter count. The four entries of `num_filters` are `depth`,
      `2 * depth`, `4 * depth` and `8 * depth`.

  Returns:
    The discriminator logits passed through `tf.layers.flatten`.
  """
  filter_counts = [multiplier * depth for multiplier in (1, 2, 4, 8)]
  logits, _ = pix2pix.pix2pix_discriminator(
      image_batch, num_filters=filter_counts)
  return tf.layers.flatten(logits)
research/gan/image_compression/networks_test.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.networks."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
networks
class NetworksTest(tf.test.TestCase):
  """Graph-construction and smoke tests for the compression networks."""

  def test_last_conv_layer(self):
    """The greatest 'conv*' key must win over every other endpoint."""
    x = tf.constant(1.0)
    y = tf.constant(0.0)
    end_points = {
        'silly': y,
        'conv2': y,
        'conv4': x,
        'logits': y,
        'conv-1': y,
    }
    self.assertEqual(x, networks._last_conv_layer(end_points))

  def test_generator_run(self):
    """The compression model must build and run on a batch of zeros."""
    img_batch = tf.zeros([3, 16, 16, 3])
    model_output = networks.compression_model(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(model_output)

  def test_generator_graph(self):
    """Output shapes must follow the input batch size, patch size and bits."""
    # `range` rather than the Python-2-only `xrange`, so that the test also
    # runs on Python 3. `zip` stops after the shorter sequence, giving the
    # pairs (3, 3), (4, 5), (5, 7), (6, 9).
    for i, batch_size in zip(range(3, 7), range(3, 11, 2)):
      tf.reset_default_graph()
      patch_size = 2 ** i
      bits = 2 ** i
      img = tf.ones([batch_size, patch_size, patch_size, 3])
      uncompressed, binary_codes, prebinary = networks.compression_model(
          img, bits)

      self.assertAllEqual([batch_size, patch_size, patch_size, 3],
                          uncompressed.shape.as_list())
      self.assertEqual([batch_size, bits], binary_codes.shape.as_list())
      self.assertEqual([batch_size, bits], prebinary.shape.as_list())

  def test_generator_invalid_input(self):
    """Rank-3 and partially-defined inputs must be rejected."""
    wrong_dim_input = tf.zeros([5, 32, 32])
    with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 4'):
      networks.compression_model(wrong_dim_input)

    not_fully_defined = tf.placeholder(tf.float32, [3, None, 32, 3])
    with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
      networks.compression_model(not_fully_defined)

  def test_discriminator_run(self):
    """The discriminator must build and run on a batch of zeros."""
    img_batch = tf.zeros([3, 70, 70, 3])
    disc_output = networks.discriminator(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(disc_output)

  def test_discriminator_graph(self):
    """The discriminator output must be rank 2 and keep the batch size."""
    # Check graph construction for a number of image size/depths and batch
    # sizes.
    for batch_size, patch_size in zip([3, 6], [70, 128]):
      tf.reset_default_graph()
      img = tf.ones([batch_size, patch_size, patch_size, 3])
      disc_output = networks.discriminator(img)

      self.assertEqual(2, disc_output.shape.ndims)
      self.assertEqual(batch_size, disc_output.shape[0])

  def test_discriminator_invalid_input(self):
    """Rank-3 inputs must be rejected by the discriminator."""
    wrong_dim_input = tf.zeros([5, 32, 32])
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
      networks.discriminator(wrong_dim_input)

    # NOTE(review): this second check calls `compression_model`, not
    # `discriminator`, which looks like a copy-paste from
    # `test_generator_invalid_input`. It is kept as-is because the error the
    # discriminator raises for this input cannot be confirmed from here.
    not_fully_defined = tf.placeholder(tf.float32, [3, None, 32, 3])
    with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
      networks.compression_model(not_fully_defined)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/summaries.py
0 → 100644
View file @
f02e6013
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Summaries utility file to share between train and eval."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
tfgan
=
tf
.
contrib
.
gan
def add_reconstruction_summaries(images, reconstructions, prebinary,
                                 num_imgs_to_visualize=8):
  """Adds image summaries.

  Args:
    images: The batch of input images.
    reconstructions: The batch of reconstructed images.
    prebinary: The prebinary codes, or `None` to skip the code histogram.
    num_imgs_to_visualize: How many images of each batch go into the summary.
  """
  comparison_grid = stack_images(
      images, reconstructions, num_imgs_to_visualize)
  tf.summary.image('real_vs_reconstruction', comparison_grid, max_outputs=1)

  if prebinary is None:
    return
  tf.summary.histogram('prebinary_codes', prebinary)
def stack_images(images, reconstructions, num_imgs_to_visualize=8):
  """Stack and reshape images to see compression effects.

  Args:
    images: The batch of input images.
    reconstructions: The batch of reconstructed images.
    num_imgs_to_visualize: Maximum number of images taken from each batch;
      also used as the number of grid columns.

  Returns:
    The result of `tfgan.eval.image_reshaper` applied to the selected input
    images followed by the selected reconstructions.
  """
  selected_originals = tf.unstack(images)[:num_imgs_to_visualize]
  selected_reconstructions = (
      tf.unstack(reconstructions)[:num_imgs_to_visualize])
  # Originals come first, so with a full set of images they fill the first
  # grid row and the reconstructions fill the second.
  return tfgan.eval.image_reshaper(
      selected_originals + selected_reconstructions,
      num_cols=num_imgs_to_visualize)
research/gan/image_compression/testdata/labels.txt
0 → 100644
View file @
f02e6013
research/gan/image_compression/testdata/train-00000-of-00128
0 → 100644
View file @
f02e6013
File added
research/gan/image_compression/testdata/validation-00000-of-00128
0 → 100644
View file @
f02e6013
File added
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment