Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5f99e589
Commit
5f99e589
authored
Feb 12, 2018
by
Joel Shor
Committed by
joel-shor
Feb 12, 2018
Browse files
Project import generated by Copybara.
PiperOrigin-RevId: 185395080
parent
809fc4d0
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
702 additions
and
352 deletions
+702
-352
research/gan/cifar/eval.py
research/gan/cifar/eval.py
+5
-3
research/gan/cifar/networks.py
research/gan/cifar/networks.py
+20
-8
research/gan/cifar/util.py
research/gan/cifar/util.py
+1
-1
research/gan/image_compression/launch_jobs.sh
research/gan/image_compression/launch_jobs.sh
+3
-2
research/gan/image_compression/networks_test.py
research/gan/image_compression/networks_test.py
+1
-1
research/gan/mnist/conditional_eval.py
research/gan/mnist/conditional_eval.py
+10
-4
research/gan/mnist/eval.py
research/gan/mnist/eval.py
+5
-2
research/gan/mnist/infogan_eval.py
research/gan/mnist/infogan_eval.py
+12
-7
research/gan/mnist/networks.py
research/gan/mnist/networks.py
+38
-22
research/gan/mnist/util.py
research/gan/mnist/util.py
+1
-1
research/gan/mnist_estimator/train.py
research/gan/mnist_estimator/train.py
+8
-2
research/gan/tutorial.ipynb
research/gan/tutorial.ipynb
+598
-299
No files found.
research/gan/cifar/eval.py
View file @
5f99e589
...
@@ -63,6 +63,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
...
@@ -63,6 +63,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'Number of times to run evaluation. If `None`, run '
'forever.'
)
'forever.'
)
flags
.
DEFINE_boolean
(
'write_to_disk'
,
True
,
'If `True`, run images to disk.'
)
def
main
(
_
,
run_eval_loop
=
True
):
def
main
(
_
,
run_eval_loop
=
True
):
# Fetch and generate images to run through Inception.
# Fetch and generate images to run through Inception.
...
@@ -97,7 +99,7 @@ def main(_, run_eval_loop=True):
...
@@ -97,7 +99,7 @@ def main(_, run_eval_loop=True):
# Create ops that write images to disk.
# Create ops that write images to disk.
image_write_ops
=
None
image_write_ops
=
None
if
FLAGS
.
conditional_eval
:
if
FLAGS
.
conditional_eval
and
FLAGS
.
write_to_disk
:
reshaped_imgs
=
util
.
get_image_grid
(
reshaped_imgs
=
util
.
get_image_grid
(
generated_data
,
FLAGS
.
num_images_generated
,
num_classes
,
generated_data
,
FLAGS
.
num_images_generated
,
num_classes
,
FLAGS
.
num_images_per_class
)
FLAGS
.
num_images_per_class
)
...
@@ -106,7 +108,7 @@ def main(_, run_eval_loop=True):
...
@@ -106,7 +108,7 @@ def main(_, run_eval_loop=True):
'%s/%s'
%
(
FLAGS
.
eval_dir
,
'conditional_cifar10.png'
),
'%s/%s'
%
(
FLAGS
.
eval_dir
,
'conditional_cifar10.png'
),
tf
.
image
.
encode_png
(
uint8_images
[
0
]))
tf
.
image
.
encode_png
(
uint8_images
[
0
]))
else
:
else
:
if
FLAGS
.
num_images_generated
>=
100
:
if
FLAGS
.
num_images_generated
>=
100
and
FLAGS
.
write_to_disk
:
reshaped_imgs
=
tfgan
.
eval
.
image_reshaper
(
reshaped_imgs
=
tfgan
.
eval
.
image_reshaper
(
generated_data
[:
100
],
num_cols
=
FLAGS
.
num_images_per_class
)
generated_data
[:
100
],
num_cols
=
FLAGS
.
num_images_per_class
)
uint8_images
=
data_provider
.
float_image_to_uint8
(
reshaped_imgs
)
uint8_images
=
data_provider
.
float_image_to_uint8
(
reshaped_imgs
)
...
@@ -147,7 +149,7 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
...
@@ -147,7 +149,7 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
# In order for variables to load, use the same variable scope as in the
# In order for variables to load, use the same variable scope as in the
# train job.
# train job.
with
tf
.
variable_scope
(
'Generator'
):
with
tf
.
variable_scope
(
'Generator'
):
data
=
generator_fn
(
generator_inputs
)
data
=
generator_fn
(
generator_inputs
,
is_training
=
False
)
return
data
return
data
...
...
research/gan/cifar/networks.py
View file @
5f99e589
...
@@ -32,29 +32,35 @@ def _last_conv_layer(end_points):
...
@@ -32,29 +32,35 @@ def _last_conv_layer(end_points):
return
end_points
[
conv_list
[
-
1
]]
return
end_points
[
conv_list
[
-
1
]]
def
generator
(
noise
):
def
generator
(
noise
,
is_training
=
True
):
"""Generator to produce CIFAR images.
"""Generator to produce CIFAR images.
Args:
Args:
noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
does not use conditioning, this Tensor represents a noise vector of some
does not use conditioning, this Tensor represents a noise vector of some
kind that will be reshaped by the generator into CIFAR examples.
kind that will be reshaped by the generator into CIFAR examples.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A single Tensor with a batch of generated CIFAR images.
A single Tensor with a batch of generated CIFAR images.
"""
"""
images
,
_
=
dcgan
.
generator
(
noise
)
images
,
_
=
dcgan
.
generator
(
noise
,
is_training
=
is_training
)
# Make sure output lies between [-1, 1].
# Make sure output lies between [-1, 1].
return
tf
.
tanh
(
images
)
return
tf
.
tanh
(
images
)
def
conditional_generator
(
inputs
):
def
conditional_generator
(
inputs
,
is_training
=
True
):
"""Generator to produce CIFAR images.
"""Generator to produce CIFAR images.
Args:
Args:
inputs: A 2-tuple of Tensors (noise, one_hot_labels) and creates a
inputs: A 2-tuple of Tensors (noise, one_hot_labels) and creates a
conditional generator.
conditional generator.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A single Tensor with a batch of generated CIFAR images.
A single Tensor with a batch of generated CIFAR images.
...
@@ -62,13 +68,13 @@ def conditional_generator(inputs):
...
@@ -62,13 +68,13 @@ def conditional_generator(inputs):
noise
,
one_hot_labels
=
inputs
noise
,
one_hot_labels
=
inputs
noise
=
tfgan
.
features
.
condition_tensor_from_onehot
(
noise
,
one_hot_labels
)
noise
=
tfgan
.
features
.
condition_tensor_from_onehot
(
noise
,
one_hot_labels
)
images
,
_
=
dcgan
.
generator
(
noise
)
images
,
_
=
dcgan
.
generator
(
noise
,
is_training
=
is_training
)
# Make sure output lies between [-1, 1].
# Make sure output lies between [-1, 1].
return
tf
.
tanh
(
images
)
return
tf
.
tanh
(
images
)
def
discriminator
(
img
,
unused_conditioning
):
def
discriminator
(
img
,
unused_conditioning
,
is_training
=
True
):
"""Discriminator for CIFAR images.
"""Discriminator for CIFAR images.
Args:
Args:
...
@@ -79,20 +85,23 @@ def discriminator(img, unused_conditioning):
...
@@ -79,20 +85,23 @@ def discriminator(img, unused_conditioning):
would require extra `condition` information to both the generator and the
would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this
discriminator. Since this example is not conditional, we do not use this
argument.
argument.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A 1D Tensor of shape [batch size] representing the confidence that the
A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values
images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real.
indicating high confidence that the images are real.
"""
"""
logits
,
_
=
dcgan
.
discriminator
(
img
)
logits
,
_
=
dcgan
.
discriminator
(
img
,
is_training
=
is_training
)
return
logits
return
logits
# (joelshor): This discriminator creates variables that aren't used, and
# (joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# so extraneous variables aren't created.
# so extraneous variables aren't created.
def
conditional_discriminator
(
img
,
conditioning
):
def
conditional_discriminator
(
img
,
conditioning
,
is_training
=
True
):
"""Discriminator for CIFAR images.
"""Discriminator for CIFAR images.
Args:
Args:
...
@@ -100,13 +109,16 @@ def conditional_discriminator(img, conditioning):
...
@@ -100,13 +109,16 @@ def conditional_discriminator(img, conditioning):
either real or generated. It is the discriminator's goal to distinguish
either real or generated. It is the discriminator's goal to distinguish
between the two.
between the two.
conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A 1D Tensor of shape [batch size] representing the confidence that the
A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values
images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real.
indicating high confidence that the images are real.
"""
"""
logits
,
end_points
=
dcgan
.
discriminator
(
img
)
logits
,
end_points
=
dcgan
.
discriminator
(
img
,
is_training
=
is_training
)
# Condition the last convolution layer.
# Condition the last convolution layer.
_
,
one_hot_labels
=
conditioning
_
,
one_hot_labels
=
conditioning
...
...
research/gan/cifar/util.py
View file @
5f99e589
...
@@ -18,7 +18,7 @@ from __future__ import absolute_import
...
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
six.moves
import
xrange
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
...
research/gan/image_compression/launch_jobs.sh
View file @
5f99e589
...
@@ -57,8 +57,9 @@ Banner () {
...
@@ -57,8 +57,9 @@ Banner () {
echo
-e
"
${
green
}${
text
}${
nc
}
"
echo
-e
"
${
green
}${
text
}${
nc
}
"
}
}
# Download the dataset.
# Download the dataset. You will be asked for an ImageNet username and password.
bazel build
"
${
git_repo
}
/research/slim:download_and_convert_imagenet"
# To get one, register at http://www.image-net.org/.
bazel build
"
${
git_repo
}
/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet"
${
DATASET_DIR
}
"./bazel-bin/download_and_convert_imagenet"
${
DATASET_DIR
}
# Run the compression model.
# Run the compression model.
...
...
research/gan/image_compression/networks_test.py
View file @
5f99e589
...
@@ -18,8 +18,8 @@ from __future__ import absolute_import
...
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
xrange
import
networks
import
networks
...
...
research/gan/mnist/conditional_eval.py
View file @
5f99e589
...
@@ -49,6 +49,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
...
@@ -49,6 +49,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'Number of times to run evaluation. If `None`, run '
'forever.'
)
'forever.'
)
flags
.
DEFINE_boolean
(
'write_to_disk'
,
True
,
'If `True`, run images to disk.'
)
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
NUM_CLASSES
=
10
NUM_CLASSES
=
10
...
@@ -60,7 +62,8 @@ def main(_, run_eval_loop=True):
...
@@ -60,7 +62,8 @@ def main(_, run_eval_loop=True):
# Generate images.
# Generate images.
with
tf
.
variable_scope
(
'Generator'
):
# Same scope as in train job.
with
tf
.
variable_scope
(
'Generator'
):
# Same scope as in train job.
images
=
networks
.
conditional_generator
((
noise
,
one_hot_labels
))
images
=
networks
.
conditional_generator
(
(
noise
,
one_hot_labels
),
is_training
=
False
)
# Visualize images.
# Visualize images.
reshaped_img
=
tfgan
.
eval
.
image_reshaper
(
reshaped_img
=
tfgan
.
eval
.
image_reshaper
(
...
@@ -75,9 +78,12 @@ def main(_, run_eval_loop=True):
...
@@ -75,9 +78,12 @@ def main(_, run_eval_loop=True):
images
,
one_hot_labels
,
FLAGS
.
classifier_filename
))
images
,
one_hot_labels
,
FLAGS
.
classifier_filename
))
# Write images to disk.
# Write images to disk.
image_write_ops
=
tf
.
write_file
(
image_write_ops
=
None
'%s/%s'
%
(
FLAGS
.
eval_dir
,
'conditional_gan.png'
),
if
FLAGS
.
write_to_disk
:
tf
.
image
.
encode_png
(
data_provider
.
float_image_to_uint8
(
reshaped_img
[
0
])))
image_write_ops
=
tf
.
write_file
(
'%s/%s'
%
(
FLAGS
.
eval_dir
,
'conditional_gan.png'
),
tf
.
image
.
encode_png
(
data_provider
.
float_image_to_uint8
(
reshaped_img
[
0
])))
# For unit testing, use `run_eval_loop=False`.
# For unit testing, use `run_eval_loop=False`.
if
not
run_eval_loop
:
return
if
not
run_eval_loop
:
return
...
...
research/gan/mnist/eval.py
View file @
5f99e589
...
@@ -56,6 +56,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
...
@@ -56,6 +56,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'Number of times to run evaluation. If `None`, run '
'forever.'
)
'forever.'
)
flags
.
DEFINE_boolean
(
'write_to_disk'
,
True
,
'If `True`, run images to disk.'
)
def
main
(
_
,
run_eval_loop
=
True
):
def
main
(
_
,
run_eval_loop
=
True
):
# Fetch real images.
# Fetch real images.
...
@@ -72,13 +74,14 @@ def main(_, run_eval_loop=True):
...
@@ -72,13 +74,14 @@ def main(_, run_eval_loop=True):
# train job.
# train job.
with
tf
.
variable_scope
(
'Generator'
):
with
tf
.
variable_scope
(
'Generator'
):
images
=
networks
.
unconditional_generator
(
images
=
networks
.
unconditional_generator
(
tf
.
random_normal
([
FLAGS
.
num_images_generated
,
FLAGS
.
noise_dims
]))
tf
.
random_normal
([
FLAGS
.
num_images_generated
,
FLAGS
.
noise_dims
]),
is_training
=
False
)
tf
.
summary
.
scalar
(
'MNIST_Frechet_distance'
,
tf
.
summary
.
scalar
(
'MNIST_Frechet_distance'
,
util
.
mnist_frechet_distance
(
util
.
mnist_frechet_distance
(
real_images
,
images
,
FLAGS
.
classifier_filename
))
real_images
,
images
,
FLAGS
.
classifier_filename
))
tf
.
summary
.
scalar
(
'MNIST_Classifier_score'
,
tf
.
summary
.
scalar
(
'MNIST_Classifier_score'
,
util
.
mnist_score
(
images
,
FLAGS
.
classifier_filename
))
util
.
mnist_score
(
images
,
FLAGS
.
classifier_filename
))
if
FLAGS
.
num_images_generated
>=
100
:
if
FLAGS
.
num_images_generated
>=
100
and
FLAGS
.
write_to_disk
:
reshaped_images
=
tfgan
.
eval
.
image_reshaper
(
reshaped_images
=
tfgan
.
eval
.
image_reshaper
(
images
[:
100
,
...],
num_cols
=
10
)
images
[:
100
,
...],
num_cols
=
10
)
uint8_images
=
data_provider
.
float_image_to_uint8
(
reshaped_images
)
uint8_images
=
data_provider
.
float_image_to_uint8
(
reshaped_images
)
...
...
research/gan/mnist/infogan_eval.py
View file @
5f99e589
...
@@ -62,6 +62,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
...
@@ -62,6 +62,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'Number of times to run evaluation. If `None`, run '
'forever.'
)
'forever.'
)
flags
.
DEFINE_boolean
(
'write_to_disk'
,
True
,
'If `True`, run images to disk.'
)
CAT_SAMPLE_POINTS
=
np
.
arange
(
0
,
10
)
CAT_SAMPLE_POINTS
=
np
.
arange
(
0
,
10
)
CONT_SAMPLE_POINTS
=
np
.
linspace
(
-
2.0
,
2.0
,
10
)
CONT_SAMPLE_POINTS
=
np
.
linspace
(
-
2.0
,
2.0
,
10
)
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -79,7 +81,9 @@ def main(_, run_eval_loop=True):
...
@@ -79,7 +81,9 @@ def main(_, run_eval_loop=True):
# Visualize the effect of each structured noise dimension on the generated
# Visualize the effect of each structured noise dimension on the generated
# image.
# image.
generator_fn
=
lambda
x
:
networks
.
infogan_generator
(
x
,
len
(
CAT_SAMPLE_POINTS
))
def
generator_fn
(
inputs
):
return
networks
.
infogan_generator
(
inputs
,
len
(
CAT_SAMPLE_POINTS
),
is_training
=
False
)
with
tf
.
variable_scope
(
'Generator'
)
as
genscope
:
# Same scope as in training.
with
tf
.
variable_scope
(
'Generator'
)
as
genscope
:
# Same scope as in training.
categorical_images
=
generator_fn
(
display_noise1
)
categorical_images
=
generator_fn
(
display_noise1
)
reshaped_categorical_img
=
tfgan
.
eval
.
image_reshaper
(
reshaped_categorical_img
=
tfgan
.
eval
.
image_reshaper
(
...
@@ -106,12 +110,13 @@ def main(_, run_eval_loop=True):
...
@@ -106,12 +110,13 @@ def main(_, run_eval_loop=True):
# Write images to disk.
# Write images to disk.
image_write_ops
=
[]
image_write_ops
=
[]
image_write_ops
.
append
(
_get_write_image_ops
(
if
FLAGS
.
write_to_disk
:
FLAGS
.
eval_dir
,
'categorical_infogan.png'
,
reshaped_categorical_img
[
0
]))
image_write_ops
.
append
(
_get_write_image_ops
(
image_write_ops
.
append
(
_get_write_image_ops
(
FLAGS
.
eval_dir
,
'categorical_infogan.png'
,
reshaped_categorical_img
[
0
]))
FLAGS
.
eval_dir
,
'continuous1_infogan.png'
,
reshaped_continuous1_img
[
0
]))
image_write_ops
.
append
(
_get_write_image_ops
(
image_write_ops
.
append
(
_get_write_image_ops
(
FLAGS
.
eval_dir
,
'continuous1_infogan.png'
,
reshaped_continuous1_img
[
0
]))
FLAGS
.
eval_dir
,
'continuous2_infogan.png'
,
reshaped_continuous2_img
[
0
]))
image_write_ops
.
append
(
_get_write_image_ops
(
FLAGS
.
eval_dir
,
'continuous2_infogan.png'
,
reshaped_continuous2_img
[
0
]))
# For unit testing, use `run_eval_loop=False`.
# For unit testing, use `run_eval_loop=False`.
if
not
run_eval_loop
:
return
if
not
run_eval_loop
:
return
...
...
research/gan/mnist/networks.py
View file @
5f99e589
...
@@ -26,7 +26,7 @@ tfgan = tf.contrib.gan
...
@@ -26,7 +26,7 @@ tfgan = tf.contrib.gan
def
_generator_helper
(
def
_generator_helper
(
noise
,
is_conditional
,
one_hot_labels
,
weight_decay
):
noise
,
is_conditional
,
one_hot_labels
,
weight_decay
,
is_training
):
"""Core MNIST generator.
"""Core MNIST generator.
This function is reused between the different GAN modes (unconditional,
This function is reused between the different GAN modes (unconditional,
...
@@ -37,6 +37,9 @@ def _generator_helper(
...
@@ -37,6 +37,9 @@ def _generator_helper(
is_conditional: Whether to condition on labels.
is_conditional: Whether to condition on labels.
one_hot_labels: Optional labels for conditioning.
one_hot_labels: Optional labels for conditioning.
weight_decay: The value of the l2 weight decay.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A generated image in the range [-1, 1].
A generated image in the range [-1, 1].
...
@@ -45,49 +48,59 @@ def _generator_helper(
...
@@ -45,49 +48,59 @@ def _generator_helper(
[
layers
.
fully_connected
,
layers
.
conv2d_transpose
],
[
layers
.
fully_connected
,
layers
.
conv2d_transpose
],
activation_fn
=
tf
.
nn
.
relu
,
normalizer_fn
=
layers
.
batch_norm
,
activation_fn
=
tf
.
nn
.
relu
,
normalizer_fn
=
layers
.
batch_norm
,
weights_regularizer
=
layers
.
l2_regularizer
(
weight_decay
)):
weights_regularizer
=
layers
.
l2_regularizer
(
weight_decay
)):
net
=
layers
.
fully_connected
(
noise
,
1024
)
with
tf
.
contrib
.
framework
.
arg_scope
(
if
is_conditional
:
[
layers
.
batch_norm
],
is_training
=
is_training
):
net
=
tfgan
.
features
.
condition_tensor_from_onehot
(
net
,
one_hot_labels
)
net
=
layers
.
fully_connected
(
noise
,
1024
)
net
=
layers
.
fully_connected
(
net
,
7
*
7
*
128
)
if
is_conditional
:
net
=
tf
.
reshape
(
net
,
[
-
1
,
7
,
7
,
128
])
net
=
tfgan
.
features
.
condition_tensor_from_onehot
(
net
,
one_hot_labels
)
net
=
layers
.
conv2d_transpose
(
net
,
64
,
[
4
,
4
],
stride
=
2
)
net
=
layers
.
fully_connected
(
net
,
7
*
7
*
128
)
net
=
layers
.
conv2d_transpose
(
net
,
32
,
[
4
,
4
],
stride
=
2
)
net
=
tf
.
reshape
(
net
,
[
-
1
,
7
,
7
,
128
])
# Make sure that generator output is in the same range as `inputs`
net
=
layers
.
conv2d_transpose
(
net
,
64
,
[
4
,
4
],
stride
=
2
)
# ie [-1, 1].
net
=
layers
.
conv2d_transpose
(
net
,
32
,
[
4
,
4
],
stride
=
2
)
net
=
layers
.
conv2d
(
# Make sure that generator output is in the same range as `inputs`
net
,
1
,
[
4
,
4
],
normalizer_fn
=
None
,
activation_fn
=
tf
.
tanh
)
# ie [-1, 1].
net
=
layers
.
conv2d
(
return
net
net
,
1
,
[
4
,
4
],
normalizer_fn
=
None
,
activation_fn
=
tf
.
tanh
)
return
net
def
unconditional_generator
(
noise
,
weight_decay
=
2.5e-5
):
def
unconditional_generator
(
noise
,
weight_decay
=
2.5e-5
,
is_training
=
True
):
"""Generator to produce unconditional MNIST images.
"""Generator to produce unconditional MNIST images.
Args:
Args:
noise: A single Tensor representing noise.
noise: A single Tensor representing noise.
weight_decay: The value of the l2 weight decay.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A generated image in the range [-1, 1].
A generated image in the range [-1, 1].
"""
"""
return
_generator_helper
(
noise
,
False
,
None
,
weight_decay
)
return
_generator_helper
(
noise
,
False
,
None
,
weight_decay
,
is_training
)
def
conditional_generator
(
inputs
,
weight_decay
=
2.5e-5
):
def
conditional_generator
(
inputs
,
weight_decay
=
2.5e-5
,
is_training
=
True
):
"""Generator to produce MNIST images conditioned on class.
"""Generator to produce MNIST images conditioned on class.
Args:
Args:
inputs: A 2-tuple of Tensors (noise, one_hot_labels).
inputs: A 2-tuple of Tensors (noise, one_hot_labels).
weight_decay: The value of the l2 weight decay.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A generated image in the range [-1, 1].
A generated image in the range [-1, 1].
"""
"""
noise
,
one_hot_labels
=
inputs
noise
,
one_hot_labels
=
inputs
return
_generator_helper
(
noise
,
True
,
one_hot_labels
,
weight_decay
)
return
_generator_helper
(
noise
,
True
,
one_hot_labels
,
weight_decay
,
is_training
)
def
infogan_generator
(
inputs
,
categorical_dim
,
weight_decay
=
2.5e-5
):
def
infogan_generator
(
inputs
,
categorical_dim
,
weight_decay
=
2.5e-5
,
is_training
=
True
):
"""InfoGAN generator network on MNIST digits.
"""InfoGAN generator network on MNIST digits.
Based on a paper https://arxiv.org/abs/1606.03657, their code
Based on a paper https://arxiv.org/abs/1606.03657, their code
...
@@ -99,6 +112,9 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
...
@@ -99,6 +112,9 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
2D, and `inputs[1]` must be 1D. All must have the same first dimension.
2D, and `inputs[1]` must be 1D. All must have the same first dimension.
categorical_dim: Dimensions of the incompressible categorical noise.
categorical_dim: Dimensions of the incompressible categorical noise.
weight_decay: The value of the l2 weight decay.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Returns:
A generated image in the range [-1, 1].
A generated image in the range [-1, 1].
...
@@ -107,7 +123,7 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
...
@@ -107,7 +123,7 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
cat_noise_onehot
=
tf
.
one_hot
(
cat_noise
,
categorical_dim
)
cat_noise_onehot
=
tf
.
one_hot
(
cat_noise
,
categorical_dim
)
all_noise
=
tf
.
concat
(
all_noise
=
tf
.
concat
(
[
unstructured_noise
,
cat_noise_onehot
,
cont_noise
],
axis
=
1
)
[
unstructured_noise
,
cat_noise_onehot
,
cont_noise
],
axis
=
1
)
return
_generator_helper
(
all_noise
,
False
,
None
,
weight_decay
)
return
_generator_helper
(
all_noise
,
False
,
None
,
weight_decay
,
is_training
)
_leaky_relu
=
lambda
x
:
tf
.
nn
.
leaky_relu
(
x
,
alpha
=
0.01
)
_leaky_relu
=
lambda
x
:
tf
.
nn
.
leaky_relu
(
x
,
alpha
=
0.01
)
...
...
research/gan/mnist/util.py
View file @
5f99e589
...
@@ -24,7 +24,7 @@ from __future__ import print_function
...
@@ -24,7 +24,7 @@ from __future__ import print_function
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
xrange
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
ds
=
tf
.
contrib
.
distributions
ds
=
tf
.
contrib
.
distributions
...
...
research/gan/mnist_estimator/train.py
View file @
5f99e589
...
@@ -22,7 +22,7 @@ import os
...
@@ -22,7 +22,7 @@ import os
import
numpy
as
np
import
numpy
as
np
import
scipy.misc
import
scipy.misc
from
six.moves
import
xrange
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
mnist
import
data_provider
from
mnist
import
data_provider
...
@@ -66,10 +66,16 @@ def _get_predict_input_fn(batch_size, noise_dims):
...
@@ -66,10 +66,16 @@ def _get_predict_input_fn(batch_size, noise_dims):
return
predict_input_fn
return
predict_input_fn
def
_unconditional_generator
(
noise
,
mode
):
"""MNIST generator with extra argument for tf.Estimator's `mode`."""
is_training
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
)
return
networks
.
unconditional_generator
(
noise
,
is_training
=
is_training
)
def
main
(
_
):
def
main
(
_
):
# Initialize GANEstimator with options and hyperparameters.
# Initialize GANEstimator with options and hyperparameters.
gan_estimator
=
tfgan
.
estimator
.
GANEstimator
(
gan_estimator
=
tfgan
.
estimator
.
GANEstimator
(
generator_fn
=
networks
.
unconditional_generator
,
generator_fn
=
_
unconditional_generator
,
discriminator_fn
=
networks
.
unconditional_discriminator
,
discriminator_fn
=
networks
.
unconditional_discriminator
,
generator_loss_fn
=
tfgan
.
losses
.
wasserstein_generator_loss
,
generator_loss_fn
=
tfgan
.
losses
.
wasserstein_generator_loss
,
discriminator_loss_fn
=
tfgan
.
losses
.
wasserstein_discriminator_loss
,
discriminator_loss_fn
=
tfgan
.
losses
.
wasserstein_discriminator_loss
,
...
...
research/gan/tutorial.ipynb
View file @
5f99e589
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment