Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b6907e8d
Commit
b6907e8d
authored
Nov 28, 2017
by
Joel Shor
Committed by
joel-shor
Nov 28, 2017
Browse files
Project import generated by Copybara.
PiperOrigin-RevId: 177165761
parent
220772b5
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
983 additions
and
74 deletions
+983
-74
research/gan/cifar/util_test.py
research/gan/cifar/util_test.py
+19
-20
research/gan/image_compression/data_provider.py
research/gan/image_compression/data_provider.py
+93
-0
research/gan/image_compression/data_provider_test.py
research/gan/image_compression/data_provider_test.py
+60
-0
research/gan/image_compression/eval.py
research/gan/image_compression/eval.py
+101
-0
research/gan/image_compression/eval_test.py
research/gan/image_compression/eval_test.py
+32
-0
research/gan/image_compression/launch_jobs.sh
research/gan/image_compression/launch_jobs.sh
+83
-0
research/gan/image_compression/networks.py
research/gan/image_compression/networks.py
+145
-0
research/gan/image_compression/networks_test.py
research/gan/image_compression/networks_test.py
+98
-0
research/gan/image_compression/summaries.py
research/gan/image_compression/summaries.py
+44
-0
research/gan/image_compression/testdata/labels.txt
research/gan/image_compression/testdata/labels.txt
+1
-0
research/gan/image_compression/testdata/train-00000-of-00128
research/gan/image_compression/testdata/train-00000-of-00128
+0
-0
research/gan/image_compression/testdata/validation-00000-of-00128
.../gan/image_compression/testdata/validation-00000-of-00128
+0
-0
research/gan/image_compression/train.py
research/gan/image_compression/train.py
+217
-0
research/gan/image_compression/train_test.py
research/gan/image_compression/train_test.py
+57
-0
research/gan/mnist/eval_test.py
research/gan/mnist/eval_test.py
+1
-1
research/gan/mnist/train_test.py
research/gan/mnist/train_test.py
+6
-6
research/gan/mnist_estimator/train_test.py
research/gan/mnist_estimator/train_test.py
+5
-5
research/gan/tutorial.ipynb
research/gan/tutorial.ipynb
+21
-42
No files found.
research/gan/cifar/util_test.py
View file @
b6907e8d
...
...
@@ -37,26 +37,25 @@ class UtilTest(tf.test.TestCase):
num_classes
=
3
,
num_images_per_class
=
1
)
def
test_get_inception_scores
(
self
):
# Mock `inception_score` which is expensive.
with
mock
.
patch
.
object
(
util
.
tfgan
.
eval
,
'inception_score'
)
as
mock_inception_score
:
mock_inception_score
.
return_value
=
1.0
util
.
get_inception_scores
(
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
batch_size
=
100
,
num_inception_images
=
10
)
def
test_get_frechet_inception_distance
(
self
):
# Mock `frechet_inception_distance` which is expensive.
with
mock
.
patch
.
object
(
util
.
tfgan
.
eval
,
'frechet_inception_distance'
)
as
mock_fid
:
mock_fid
.
return_value
=
1.0
util
.
get_frechet_inception_distance
(
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
batch_size
=
100
,
num_inception_images
=
10
)
# Mock `inception_score` which is expensive.
@
mock
.
patch
.
object
(
util
.
tfgan
.
eval
,
'inception_score'
,
autospec
=
True
)
def
test_get_inception_scores
(
self
,
mock_inception_score
):
mock_inception_score
.
return_value
=
1.0
util
.
get_inception_scores
(
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
batch_size
=
100
,
num_inception_images
=
10
)
# Mock `frechet_inception_distance` which is expensive.
@
mock
.
patch
.
object
(
util
.
tfgan
.
eval
,
'frechet_inception_distance'
,
autospec
=
True
)
def
test_get_frechet_inception_distance
(
self
,
mock_fid
):
mock_fid
.
return_value
=
1.0
util
.
get_frechet_inception_distance
(
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
28
,
28
,
3
]),
batch_size
=
100
,
num_inception_images
=
10
)
if
__name__
==
'__main__'
:
...
...
research/gan/image_compression/data_provider.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the compression image data."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
slim.datasets
import
dataset_factory
as
datasets
slim
=
tf
.
contrib
.
slim
def
provide_data
(
split_name
,
batch_size
,
dataset_dir
,
dataset_name
=
'imagenet'
,
num_readers
=
1
,
num_threads
=
1
,
patch_size
=
128
):
"""Provides batches of image data for compression.
Args:
split_name: Either 'train' or 'validation'.
batch_size: The number of images in each batch.
dataset_dir: The directory where the data can be found. If `None`, use
default.
dataset_name: Name of the dataset.
num_readers: Number of dataset readers.
num_threads: Number of prefetching threads.
patch_size: Size of the path to extract from the image.
Returns:
images: A `Tensor` of size [batch_size, patch_size, patch_size, channels]
"""
randomize
=
split_name
==
'train'
dataset
=
datasets
.
get_dataset
(
dataset_name
,
split_name
,
dataset_dir
=
dataset_dir
)
provider
=
slim
.
dataset_data_provider
.
DatasetDataProvider
(
dataset
,
num_readers
=
num_readers
,
common_queue_capacity
=
5
*
batch_size
,
common_queue_min
=
batch_size
,
shuffle
=
randomize
)
[
image
]
=
provider
.
get
([
'image'
])
# Sample a patch of fixed size.
patch
=
tf
.
image
.
resize_image_with_crop_or_pad
(
image
,
patch_size
,
patch_size
)
patch
.
shape
.
assert_is_compatible_with
([
patch_size
,
patch_size
,
3
])
# Preprocess the images. Make the range lie in a strictly smaller range than
# [-1, 1], so that network outputs aren't forced to the extreme ranges.
patch
=
(
tf
.
to_float
(
patch
)
-
128.0
)
/
142.0
if
randomize
:
image_batch
=
tf
.
train
.
shuffle_batch
(
[
patch
],
batch_size
=
batch_size
,
num_threads
=
num_threads
,
capacity
=
5
*
batch_size
,
min_after_dequeue
=
batch_size
)
else
:
image_batch
=
tf
.
train
.
batch
(
[
patch
],
batch_size
=
batch_size
,
num_threads
=
1
,
# no threads so it's deterministic
capacity
=
5
*
batch_size
)
return
image_batch
def
float_image_to_uint8
(
image
):
"""Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.
Args:
image: An image tensor. Values should be in [-0.9, 0.9).
Returns:
Input image cast to uint8 and with integer values in [0, 255].
"""
image
=
(
image
*
142.0
)
+
128.0
return
tf
.
cast
(
image
,
tf
.
uint8
)
research/gan/image_compression/data_provider_test.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
import
data_provider
class
DataProviderTest
(
tf
.
test
.
TestCase
):
def
_test_data_provider_helper
(
self
,
split_name
):
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/'
)
batch_size
=
3
patch_size
=
8
images
=
data_provider
.
provide_data
(
split_name
,
batch_size
,
dataset_dir
,
patch_size
=
8
)
self
.
assertListEqual
([
batch_size
,
patch_size
,
patch_size
,
3
],
images
.
shape
.
as_list
())
with
self
.
test_session
(
use_gpu
=
True
)
as
sess
:
with
tf
.
contrib
.
slim
.
queues
.
QueueRunners
(
sess
):
images_out
=
sess
.
run
(
images
)
self
.
assertEqual
((
batch_size
,
patch_size
,
patch_size
,
3
),
images_out
.
shape
)
# Check range.
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
def
test_data_provider_train
(
self
):
self
.
_test_data_provider_helper
(
'train'
)
def
test_data_provider_validation
(
self
):
self
.
_test_data_provider_helper
(
'validation'
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/eval.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained compression model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
data_provider
import
networks
import
summaries
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_string
(
'checkpoint_dir'
,
'/tmp/compression/'
,
'Directory where the model was written to.'
)
flags
.
DEFINE_string
(
'eval_dir'
,
'/tmp/compression/'
,
'Directory where the results are saved to.'
)
flags
.
DEFINE_integer
(
'max_number_of_evaluations'
,
None
,
'Number of times to run evaluation. If `None`, run '
'forever.'
)
flags
.
DEFINE_string
(
'dataset_dir'
,
None
,
'Location of data.'
)
# Compression-specific flags.
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'patch_size'
,
32
,
'The size of the patches to train on.'
)
flags
.
DEFINE_integer
(
'bits_per_patch'
,
1230
,
'The number of bits to produce per patch.'
)
flags
.
DEFINE_integer
(
'model_depth'
,
64
,
'Number of filters for compression model'
)
def
main
(
_
,
run_eval_loop
=
True
):
with
tf
.
name_scope
(
'inputs'
):
images
=
data_provider
.
provide_data
(
'validation'
,
FLAGS
.
batch_size
,
dataset_dir
=
FLAGS
.
dataset_dir
,
patch_size
=
FLAGS
.
patch_size
)
# In order for variables to load, use the same variable scope as in the
# train job.
with
tf
.
variable_scope
(
'generator'
):
reconstructions
,
_
,
prebinary
=
networks
.
compression_model
(
images
,
num_bits
=
FLAGS
.
bits_per_patch
,
depth
=
FLAGS
.
model_depth
,
is_training
=
False
)
summaries
.
add_reconstruction_summaries
(
images
,
reconstructions
,
prebinary
)
# Visualize losses.
pixel_loss_per_example
=
tf
.
reduce_mean
(
tf
.
abs
(
images
-
reconstructions
),
axis
=
[
1
,
2
,
3
])
pixel_loss
=
tf
.
reduce_mean
(
pixel_loss_per_example
)
tf
.
summary
.
histogram
(
'pixel_l1_loss_hist'
,
pixel_loss_per_example
)
tf
.
summary
.
scalar
(
'pixel_l1_loss'
,
pixel_loss
)
# Create ops to write images to disk.
uint8_images
=
data_provider
.
float_image_to_uint8
(
images
)
uint8_reconstructions
=
data_provider
.
float_image_to_uint8
(
reconstructions
)
uint8_reshaped
=
summaries
.
stack_images
(
uint8_images
,
uint8_reconstructions
)
image_write_ops
=
tf
.
write_file
(
'%s/%s'
%
(
FLAGS
.
eval_dir
,
'compression.png'
),
tf
.
image
.
encode_png
(
uint8_reshaped
[
0
]))
# For unit testing, use `run_eval_loop=False`.
if
not
run_eval_loop
:
return
tf
.
contrib
.
training
.
evaluate_repeatedly
(
FLAGS
.
checkpoint_dir
,
master
=
FLAGS
.
master
,
hooks
=
[
tf
.
contrib
.
training
.
SummaryAtEndHook
(
FLAGS
.
eval_dir
),
tf
.
contrib
.
training
.
StopAfterNEvalsHook
(
1
)],
eval_ops
=
image_write_ops
,
max_number_of_evaluations
=
FLAGS
.
max_number_of_evaluations
)
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
research/gan/image_compression/eval_test.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.eval."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
eval
# pylint:disable=redefined-builtin
class
EvalTest
(
tf
.
test
.
TestCase
):
def
test_build_graph
(
self
):
eval
.
main
(
None
,
run_eval_loop
=
False
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/launch_jobs.sh
0 → 100755
View file @
b6907e8d
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains image compression model on patches from Imagenet.
# 3. Evaluates the models and writes sample images to disk.
#
# Usage:
# cd models/research/gan/image_compression
# ./launch_jobs.sh ${weight_factor} ${git_repo}
set
-e
# Weight of the adversarial loss.
weight_factor
=
$1
if
[[
"
$weight_factor
"
==
""
]]
;
then
echo
"'weight_factor' must not be empty."
exit
fi
# Location of the git repository.
git_repo
=
$2
if
[[
"
$git_repo
"
==
""
]]
;
then
echo
"'git_repo' must not be empty."
exit
fi
# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR
=
/tmp/compression-model
# Base name for where the evaluation images will be saved to.
EVAL_DIR
=
/tmp/compression-model/eval
# Where the dataset is saved to.
DATASET_DIR
=
/tmp/imagenet-data
export
PYTHONPATH
=
$PYTHONPATH
:
$git_repo
:
$git_repo
/research:
$git_repo
/research/slim:
$git_repo
/research/slim/nets
# A helper function for printing pretty output.
Banner
()
{
local
text
=
$1
local
green
=
'\033[0;32m'
local
nc
=
'\033[0m'
# No color.
echo
-e
"
${
green
}${
text
}${
nc
}
"
}
# Download the dataset.
"
${
git_repo
}
/research/slim/datasets/download_and_convert_imagenet.sh"
${
DATASET_DIR
}
# Run the compression model.
NUM_STEPS
=
10000
MODEL_TRAIN_DIR
=
"
${
TRAIN_DIR
}
/wt
${
weight_factor
}
"
Banner
"Starting training an image compression model for
${
NUM_STEPS
}
steps..."
python
"
${
git_repo
}
/research/gan/image_compression/train.py"
\
--train_log_dir
=
${
MODEL_TRAIN_DIR
}
\
--dataset_dir
=
${
DATASET_DIR
}
\
--max_number_of_steps
=
${
NUM_STEPS
}
\
--weight_factor
=
${
weight_factor
}
\
--alsologtostderr
Banner
"Finished training image compression model
${
NUM_STEPS
}
steps."
# Run evaluation.
MODEL_EVAL_DIR
=
"
${
TRAIN_DIR
}
/eval/wt
${
weight_factor
}
"
Banner
"Starting evaluation of image compression model..."
python
"
${
git_repo
}
/research/gan/image_compression/eval.py"
\
--checkpoint_dir
=
${
MODEL_TRAIN_DIR
}
\
--eval_dir
=
${
MODEL_EVAL_DIR
}
\
--dataset_dir
=
${
DATASET_DIR
}
\
--max_number_of_evaluation
=
1
Banner
"Finished evaluation. See
${
MODEL_EVAL_DIR
}
for output images."
research/gan/image_compression/networks.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN compression example using TFGAN."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
slim.nets
import
dcgan
from
slim.nets
import
pix2pix
def
_last_conv_layer
(
end_points
):
""""Returns the last convolutional layer from an endpoints dictionary."""
conv_list
=
[
k
if
k
[:
4
]
==
'conv'
else
None
for
k
in
end_points
.
keys
()]
conv_list
.
sort
()
return
end_points
[
conv_list
[
-
1
]]
def
_encoder
(
img_batch
,
is_training
=
True
,
bits
=
64
,
depth
=
64
):
"""Maps images to internal representation.
Args:
img_batch: Stuff
is_training: Stuff
bits: Number of bits per patch.
depth: Stuff
Returns:
Real-valued 2D Tensor of size [batch_size, bits].
"""
_
,
end_points
=
dcgan
.
discriminator
(
img_batch
,
depth
=
depth
,
is_training
=
is_training
,
scope
=
'Encoder'
)
# (joelshor): Make the DCGAN convolutional layer that converts to logits
# not trainable, since it doesn't affect the encoder output.
# Get the pre-logit layer, which is the last conv.
net
=
_last_conv_layer
(
end_points
)
# Transform the features to the proper number of bits.
with
tf
.
variable_scope
(
'EncoderTransformer'
):
encoded
=
tf
.
contrib
.
layers
.
conv2d
(
net
,
bits
,
kernel_size
=
1
,
stride
=
1
,
padding
=
'VALID'
,
normalizer_fn
=
None
,
activation_fn
=
None
)
encoded
=
tf
.
squeeze
(
encoded
,
[
1
,
2
])
encoded
.
shape
.
assert_has_rank
(
2
)
# Map encoded to the range [-1, 1].
return
tf
.
nn
.
softsign
(
encoded
)
def
_binarizer
(
prebinary_codes
,
is_training
):
"""Binarize compression logits.
During training, add noise, as in https://arxiv.org/pdf/1611.01704.pdf. During
eval, map [-1, 1] -> {-1, 1}.
Args:
prebinary_codes: Floating-point tensors corresponding to pre-binary codes.
Shape is [batch, code_length].
is_training: A python bool. If True, add noise. If false, binarize.
Returns:
Binarized codes. Shape is [batch, code_length].
Raises:
ValueError: If the shape of `prebinary_codes` isn't static.
"""
if
is_training
:
# In order to train codes that can be binarized during eval, we add noise as
# in https://arxiv.org/pdf/1611.01704.pdf. Another option is to use a
# stochastic node, as in https://arxiv.org/abs/1608.05148.
noise
=
tf
.
random_uniform
(
prebinary_codes
.
shape
,
minval
=-
1.0
,
maxval
=
1.0
)
return
prebinary_codes
+
noise
else
:
return
tf
.
sign
(
prebinary_codes
)
def
_decoder
(
codes
,
final_size
,
is_training
,
depth
=
64
):
"""Compression decoder."""
decoded_img
,
_
=
dcgan
.
generator
(
codes
,
depth
=
depth
,
final_size
=
final_size
,
num_outputs
=
3
,
is_training
=
is_training
,
scope
=
'Decoder'
)
# Map output to [-1, 1].
# Use softsign instead of tanh, as per empirical results of
# http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
return
tf
.
nn
.
softsign
(
decoded_img
)
def
_validate_image_inputs
(
image_batch
):
image_batch
.
shape
.
assert_has_rank
(
4
)
image_batch
.
shape
[
1
:].
assert_is_fully_defined
()
def
compression_model
(
image_batch
,
num_bits
=
64
,
depth
=
64
,
is_training
=
True
):
"""Image compression model.
Args:
image_batch: A batch of images to compress and reconstruct. Images should
be normalized already. Shape is [batch, height, width, channels].
num_bits: Desired number of bits per image in the compressed representation.
depth: The base number of filters for the encoder and decoder networks.
is_training: A python bool. If False, run in evaluation mode.
Returns:
uncompressed images, binary codes, prebinary codes
"""
image_batch
=
tf
.
convert_to_tensor
(
image_batch
)
_validate_image_inputs
(
image_batch
)
final_size
=
image_batch
.
shape
.
as_list
()[
1
]
prebinary_codes
=
_encoder
(
image_batch
,
is_training
,
num_bits
,
depth
)
binary_codes
=
_binarizer
(
prebinary_codes
,
is_training
)
uncompressed_imgs
=
_decoder
(
binary_codes
,
final_size
,
is_training
,
depth
)
return
uncompressed_imgs
,
binary_codes
,
prebinary_codes
def
discriminator
(
image_batch
,
unused_conditioning
=
None
,
depth
=
64
):
"""A thin wrapper around the pix2pix discriminator to conform to TFGAN API."""
logits
,
_
=
pix2pix
.
pix2pix_discriminator
(
image_batch
,
num_filters
=
[
depth
,
2
*
depth
,
4
*
depth
,
8
*
depth
])
return
tf
.
layers
.
flatten
(
logits
)
research/gan/image_compression/networks_test.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.networks."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
networks
class
NetworksTest
(
tf
.
test
.
TestCase
):
def
test_last_conv_layer
(
self
):
x
=
tf
.
constant
(
1.0
)
y
=
tf
.
constant
(
0.0
)
end_points
=
{
'silly'
:
y
,
'conv2'
:
y
,
'conv4'
:
x
,
'logits'
:
y
,
'conv-1'
:
y
,
}
self
.
assertEqual
(
x
,
networks
.
_last_conv_layer
(
end_points
))
def
test_generator_run
(
self
):
img_batch
=
tf
.
zeros
([
3
,
16
,
16
,
3
])
model_output
=
networks
.
compression_model
(
img_batch
)
with
self
.
test_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
model_output
)
def
test_generator_graph
(
self
):
for
i
,
batch_size
in
zip
(
xrange
(
3
,
7
),
xrange
(
3
,
11
,
2
)):
tf
.
reset_default_graph
()
patch_size
=
2
**
i
bits
=
2
**
i
img
=
tf
.
ones
([
batch_size
,
patch_size
,
patch_size
,
3
])
uncompressed
,
binary_codes
,
prebinary
=
networks
.
compression_model
(
img
,
bits
)
self
.
assertAllEqual
([
batch_size
,
patch_size
,
patch_size
,
3
],
uncompressed
.
shape
.
as_list
())
self
.
assertEqual
([
batch_size
,
bits
],
binary_codes
.
shape
.
as_list
())
self
.
assertEqual
([
batch_size
,
bits
],
prebinary
.
shape
.
as_list
())
def
test_generator_invalid_input
(
self
):
wrong_dim_input
=
tf
.
zeros
([
5
,
32
,
32
])
with
self
.
assertRaisesRegexp
(
ValueError
,
'Shape .* must have rank 4'
):
networks
.
compression_model
(
wrong_dim_input
)
not_fully_defined
=
tf
.
placeholder
(
tf
.
float32
,
[
3
,
None
,
32
,
3
])
with
self
.
assertRaisesRegexp
(
ValueError
,
'Shape .* is not fully defined'
):
networks
.
compression_model
(
not_fully_defined
)
def
test_discriminator_run
(
self
):
img_batch
=
tf
.
zeros
([
3
,
70
,
70
,
3
])
disc_output
=
networks
.
discriminator
(
img_batch
)
with
self
.
test_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
disc_output
)
def
test_discriminator_graph
(
self
):
# Check graph construction for a number of image size/depths and batch
# sizes.
for
batch_size
,
patch_size
in
zip
([
3
,
6
],
[
70
,
128
]):
tf
.
reset_default_graph
()
img
=
tf
.
ones
([
batch_size
,
patch_size
,
patch_size
,
3
])
disc_output
=
networks
.
discriminator
(
img
)
self
.
assertEqual
(
2
,
disc_output
.
shape
.
ndims
)
self
.
assertEqual
(
batch_size
,
disc_output
.
shape
[
0
])
def
test_discriminator_invalid_input
(
self
):
wrong_dim_input
=
tf
.
zeros
([
5
,
32
,
32
])
with
self
.
assertRaisesRegexp
(
ValueError
,
'Shape must be rank 4'
):
networks
.
discriminator
(
wrong_dim_input
)
not_fully_defined
=
tf
.
placeholder
(
tf
.
float32
,
[
3
,
None
,
32
,
3
])
with
self
.
assertRaisesRegexp
(
ValueError
,
'Shape .* is not fully defined'
):
networks
.
compression_model
(
not_fully_defined
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/summaries.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Summaries utility file to share between train and eval."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
tfgan
=
tf
.
contrib
.
gan
def
add_reconstruction_summaries
(
images
,
reconstructions
,
prebinary
,
num_imgs_to_visualize
=
8
):
"""Adds image summaries."""
reshaped_img
=
stack_images
(
images
,
reconstructions
,
num_imgs_to_visualize
)
tf
.
summary
.
image
(
'real_vs_reconstruction'
,
reshaped_img
,
max_outputs
=
1
)
if
prebinary
is
not
None
:
tf
.
summary
.
histogram
(
'prebinary_codes'
,
prebinary
)
def
stack_images
(
images
,
reconstructions
,
num_imgs_to_visualize
=
8
):
"""Stack and reshape images to see compression effects."""
to_reshape
=
(
tf
.
unstack
(
images
)[:
num_imgs_to_visualize
]
+
tf
.
unstack
(
reconstructions
)[:
num_imgs_to_visualize
])
reshaped_img
=
tfgan
.
eval
.
image_reshaper
(
to_reshape
,
num_cols
=
num_imgs_to_visualize
)
return
reshaped_img
research/gan/image_compression/testdata/labels.txt
0 → 100644
View file @
b6907e8d
research/gan/image_compression/testdata/train-00000-of-00128
0 → 100644
View file @
b6907e8d
File added
research/gan/image_compression/testdata/validation-00000-of-00128
0 → 100644
View file @
b6907e8d
File added
research/gan/image_compression/train.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint:disable=line-too-long
"""Trains an image compression network with an adversarial loss."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
data_provider
import
networks
import
summaries
tfgan
=
tf
.
contrib
.
gan
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'patch_size'
,
32
,
'The size of the patches to train on.'
)
flags
.
DEFINE_integer
(
'bits_per_patch'
,
1230
,
'The number of bits to produce per patch.'
)
flags
.
DEFINE_integer
(
'model_depth'
,
64
,
'Number of filters for compression model'
)
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_string
(
'train_log_dir'
,
'/tmp/compression/'
,
'Directory where to write event logs.'
)
flags
.
DEFINE_float
(
'generator_lr'
,
1e-5
,
'The compression model learning rate.'
)
flags
.
DEFINE_float
(
'discriminator_lr'
,
1e-6
,
'The discriminator learning rate.'
)
flags
.
DEFINE_integer
(
'max_number_of_steps'
,
2000000
,
'The maximum number of gradient steps.'
)
flags
.
DEFINE_integer
(
'ps_tasks'
,
0
,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.'
)
flags
.
DEFINE_integer
(
'task'
,
0
,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.'
)
flags
.
DEFINE_float
(
'weight_factor'
,
10000.0
,
'How much to weight the adversarial loss relative to pixel loss.'
)
flags
.
DEFINE_string
(
'dataset_dir'
,
None
,
'Location of data.'
)
def
main
(
_
):
if
not
tf
.
gfile
.
Exists
(
FLAGS
.
train_log_dir
):
tf
.
gfile
.
MakeDirs
(
FLAGS
.
train_log_dir
)
with
tf
.
device
(
tf
.
train
.
replica_device_setter
(
FLAGS
.
ps_tasks
)):
# Put input pipeline on CPU to reserve GPU for training.
with
tf
.
name_scope
(
'inputs'
),
tf
.
device
(
'/cpu:0'
):
images
=
data_provider
.
provide_data
(
'train'
,
FLAGS
.
batch_size
,
dataset_dir
=
FLAGS
.
dataset_dir
,
patch_size
=
FLAGS
.
patch_size
)
# Manually define a GANModel tuple. This is useful when we have custom
# code to track variables. Note that we could replace all of this with a
# call to `tfgan.gan_model`, but we don't in order to demonstrate some of
# TFGAN's flexibility.
with
tf
.
variable_scope
(
'generator'
)
as
gen_scope
:
reconstructions
,
_
,
prebinary
=
networks
.
compression_model
(
images
,
num_bits
=
FLAGS
.
bits_per_patch
,
depth
=
FLAGS
.
model_depth
)
gan_model
=
_get_gan_model
(
generator_inputs
=
images
,
generated_data
=
reconstructions
,
real_data
=
images
,
generator_scope
=
gen_scope
)
summaries
.
add_reconstruction_summaries
(
images
,
reconstructions
,
prebinary
)
tfgan
.
eval
.
add_gan_model_summaries
(
gan_model
)
# Define the GANLoss tuple using standard library functions.
with
tf
.
name_scope
(
'loss'
):
gan_loss
=
tfgan
.
gan_loss
(
gan_model
,
generator_loss_fn
=
tfgan
.
losses
.
least_squares_generator_loss
,
discriminator_loss_fn
=
tfgan
.
losses
.
least_squares_discriminator_loss
,
add_summaries
=
FLAGS
.
weight_factor
>
0
)
# Define the standard pixel loss.
l1_pixel_loss
=
tf
.
norm
(
gan_model
.
real_data
-
gan_model
.
generated_data
,
ord
=
1
)
# Modify the loss tuple to include the pixel loss. Add summaries as well.
gan_loss
=
tfgan
.
losses
.
combine_adversarial_loss
(
gan_loss
,
gan_model
,
l1_pixel_loss
,
weight_factor
=
FLAGS
.
weight_factor
)
# Get the GANTrain ops using the custom optimizers and optional
# discriminator weight clipping.
with
tf
.
name_scope
(
'train_ops'
):
gen_lr
,
dis_lr
=
_lr
(
FLAGS
.
generator_lr
,
FLAGS
.
discriminator_lr
)
gen_opt
,
dis_opt
=
_optimizer
(
gen_lr
,
dis_lr
)
train_ops
=
tfgan
.
gan_train_ops
(
gan_model
,
gan_loss
,
generator_optimizer
=
gen_opt
,
discriminator_optimizer
=
dis_opt
,
summarize_gradients
=
True
,
colocate_gradients_with_ops
=
True
,
aggregation_method
=
tf
.
AggregationMethod
.
EXPERIMENTAL_ACCUMULATE_N
)
tf
.
summary
.
scalar
(
'generator_lr'
,
gen_lr
)
tf
.
summary
.
scalar
(
'discriminator_lr'
,
dis_lr
)
# Determine the number of generator vs discriminator steps.
train_steps
=
tfgan
.
GANTrainSteps
(
generator_train_steps
=
1
,
discriminator_train_steps
=
int
(
FLAGS
.
weight_factor
>
0
))
# Run the alternating training loop. Skip it if no steps should be taken
# (used for graph construction tests).
status_message
=
tf
.
string_join
(
[
'Starting train step: '
,
tf
.
as_string
(
tf
.
train
.
get_or_create_global_step
())],
name
=
'status_message'
)
if
FLAGS
.
max_number_of_steps
==
0
:
return
tfgan
.
gan_train
(
train_ops
,
FLAGS
.
train_log_dir
,
tfgan
.
get_sequential_train_hooks
(
train_steps
),
hooks
=
[
tf
.
train
.
StopAtStepHook
(
num_steps
=
FLAGS
.
max_number_of_steps
),
tf
.
train
.
LoggingTensorHook
([
status_message
],
every_n_iter
=
10
)],
master
=
FLAGS
.
master
,
is_chief
=
FLAGS
.
task
==
0
)
def
_optimizer
(
gen_lr
,
dis_lr
):
# First is generator optimizer, second is discriminator.
adam_kwargs
=
{
'epsilon'
:
1e-8
,
'beta1'
:
0.5
,
}
return
(
tf
.
train
.
AdamOptimizer
(
gen_lr
,
**
adam_kwargs
),
tf
.
train
.
AdamOptimizer
(
dis_lr
,
**
adam_kwargs
))
def
_lr
(
gen_lr_base
,
dis_lr_base
):
"""Return the generator and discriminator learning rates."""
gen_lr_kwargs
=
{
'decay_steps'
:
60000
,
'decay_rate'
:
0.9
,
'staircase'
:
True
,
}
gen_lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
gen_lr_base
,
global_step
=
tf
.
train
.
get_or_create_global_step
(),
**
gen_lr_kwargs
)
dis_lr
=
dis_lr_base
return
gen_lr
,
dis_lr
def
_get_gan_model
(
generator_inputs
,
generated_data
,
real_data
,
generator_scope
):
"""Manually construct and return a GANModel tuple."""
generator_vars
=
tf
.
contrib
.
framework
.
get_trainable_variables
(
generator_scope
)
discriminator_fn
=
networks
.
discriminator
with
tf
.
variable_scope
(
'discriminator'
)
as
dis_scope
:
discriminator_gen_outputs
=
discriminator_fn
(
generated_data
)
with
tf
.
variable_scope
(
dis_scope
,
reuse
=
True
):
discriminator_real_outputs
=
discriminator_fn
(
real_data
)
discriminator_vars
=
tf
.
contrib
.
framework
.
get_trainable_variables
(
dis_scope
)
# Manually construct GANModel tuple.
gan_model
=
tfgan
.
GANModel
(
generator_inputs
=
generator_inputs
,
generated_data
=
generated_data
,
generator_variables
=
generator_vars
,
generator_scope
=
generator_scope
,
generator_fn
=
None
,
# not necessary
real_data
=
real_data
,
discriminator_real_outputs
=
discriminator_real_outputs
,
discriminator_gen_outputs
=
discriminator_gen_outputs
,
discriminator_variables
=
discriminator_vars
,
discriminator_scope
=
dis_scope
,
discriminator_fn
=
discriminator_fn
)
return
gan_model
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
app
.
run
()
research/gan/image_compression/train_test.py
0 → 100644
View file @
b6907e8d
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for image_compression.train."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
train
FLAGS
=
tf
.
flags
.
FLAGS
mock
=
tf
.
test
.
mock
class
TrainTest
(
tf
.
test
.
TestCase
):
def
_test_build_graph_helper
(
self
,
weight_factor
):
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
weight_factor
=
weight_factor
batch_size
=
3
patch_size
=
16
FLAGS
.
batch_size
=
batch_size
FLAGS
.
patch_size
=
patch_size
mock_imgs
=
np
.
zeros
([
batch_size
,
patch_size
,
patch_size
,
3
],
dtype
=
np
.
float32
)
with
mock
.
patch
.
object
(
train
,
'data_provider'
)
as
mock_data_provider
:
mock_data_provider
.
provide_data
.
return_value
=
mock_imgs
train
.
main
(
None
)
def
test_build_graph_noadversarialloss
(
self
):
self
.
_test_build_graph_helper
(
0.0
)
def
test_build_graph_adversarialloss
(
self
):
self
.
_test_build_graph_helper
(
1.0
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/mnist/eval_test.py
View file @
b6907e8d
...
...
@@ -21,7 +21,7 @@ from __future__ import print_function
import
tensorflow
as
tf
from
google3.third_party.tensorflow_models.gan.mnist
import
eval
# pylint:disable=redefined-builtin
import
eval
# pylint:disable=redefined-builtin
class
EvalTest
(
tf
.
test
.
TestCase
):
...
...
research/gan/mnist/train_test.py
View file @
b6907e8d
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for
tfgan.examples.
mnist.train."""
"""Tests for mnist.train."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -30,7 +30,8 @@ mock = tf.test.mock
class
TrainTest
(
tf
.
test
.
TestCase
):
def
test_run_one_train_step
(
self
):
@
mock
.
patch
.
object
(
train
,
'data_provider'
,
autospec
=
True
)
def
test_run_one_train_step
(
self
,
mock_data_provider
):
FLAGS
.
max_number_of_steps
=
1
FLAGS
.
gan_type
=
'unconditional'
FLAGS
.
batch_size
=
5
...
...
@@ -42,10 +43,9 @@ class TrainTest(tf.test.TestCase):
mock_lbls
=
np
.
concatenate
(
(
np
.
ones
([
FLAGS
.
batch_size
,
1
],
dtype
=
np
.
int32
),
np
.
zeros
([
FLAGS
.
batch_size
,
9
],
dtype
=
np
.
int32
)),
axis
=
1
)
with
mock
.
patch
.
object
(
train
,
'data_provider'
)
as
mock_data_provider
:
mock_data_provider
.
provide_data
.
return_value
=
(
mock_imgs
,
mock_lbls
,
None
)
train
.
main
(
None
)
mock_data_provider
.
provide_data
.
return_value
=
(
mock_imgs
,
mock_lbls
,
None
)
train
.
main
(
None
)
def
_test_build_graph_helper
(
self
,
gan_type
):
FLAGS
.
max_number_of_steps
=
0
...
...
research/gan/mnist_estimator/train_test.py
View file @
b6907e8d
...
...
@@ -31,7 +31,8 @@ mock = tf.test.mock
class
TrainTest
(
tf
.
test
.
TestCase
):
def
test_full_flow
(
self
):
@
mock
.
patch
.
object
(
train
,
'data_provider'
,
autospec
=
True
)
def
test_full_flow
(
self
,
mock_data_provider
):
FLAGS
.
eval_dir
=
self
.
get_temp_dir
()
FLAGS
.
batch_size
=
16
FLAGS
.
max_number_of_steps
=
2
...
...
@@ -42,10 +43,9 @@ class TrainTest(tf.test.TestCase):
mock_lbls
=
np
.
concatenate
(
(
np
.
ones
([
FLAGS
.
batch_size
,
1
],
dtype
=
np
.
int32
),
np
.
zeros
([
FLAGS
.
batch_size
,
9
],
dtype
=
np
.
int32
)),
axis
=
1
)
with
mock
.
patch
.
object
(
train
,
'data_provider'
)
as
mock_data_provider
:
mock_data_provider
.
provide_data
.
return_value
=
(
mock_imgs
,
mock_lbls
,
None
)
train
.
main
(
None
)
mock_data_provider
.
provide_data
.
return_value
=
(
mock_imgs
,
mock_lbls
,
None
)
train
.
main
(
None
)
if
__name__
==
'__main__'
:
...
...
research/gan/tutorial.ipynb
View file @
b6907e8d
...
...
@@ -23,48 +23,27 @@
"metadata": {},
"source": [
"## Table of Contents\n",
"\n",
"<a href=#installation_and_setup>Installation and Setup</a>\n",
"\n",
" <a href=#download_data>Download Data</a>\n",
"\n",
"<a href=#unconditional_example>Unconditional GAN example</a>\n",
"\n",
" <a href=#unconditional_input>Input pipeline</a>\n",
"\n",
" <a href=#unconditional_model>Model</a>\n",
"\n",
" <a href=#unconditional_loss>Loss</a>\n",
"\n",
" <a href=#unconditional_train>Train and evaluation</a>\n",
"\n",
"<a href=#ganestimator_example>GANEstimator example</a>\n",
"\n",
" <a href=#ganestimator_input>Input pipeline</a>\n",
"\n",
" <a href=#ganestimator_train>Train</a>\n",
"\n",
" <a href=#ganestimator_eval>Eval</a>\n",
"\n",
"<a href=#conditional_example>Conditional GAN example</a>\n",
"\n",
" <a href=#conditional_input>Input pipeline</a>\n",
"\n",
" <a href=#conditional_model>Model</a>\n",
"\n",
" <a href=#conditional_loss>Loss</a>\n",
"\n",
" <a href=#conditional_train>Train and evaluation</a>\n",
"\n",
"<a href=#infogan_example>InfoGAN example</a>\n",
"\n",
" <a href=#infogan_input>Input pipeline</a>\n",
"\n",
" <a href=#infogan_model>Model</a>\n",
"\n",
" <a href=#infogan_loss>Loss</a>\n",
"\n",
" <a href=#infogan_train>Train and evaluation</a>"
"<a href='#installation_and_setup'>Installation and Setup</a><br>\n",
" <a href='#download_data'>Download Data</a><br>\n",
"<a href='#unconditional_example'>Unconditional GAN example</a><br>\n",
" <a href='#unconditional_input'>Input pipeline</a><br>\n",
" <a href='#unconditional_model'>Model</a><br>\n",
" <a href='#unconditional_loss'>Loss</a><br>\n",
" <a href='#unconditional_train'>Train and evaluation</a><br>\n",
"<a href='#ganestimator_example'>GANEstimator example</a><br>\n",
" <a href='#ganestimator_input'>Input pipeline</a><br>\n",
" <a href='#ganestimator_train'>Train</a><br>\n",
" <a href='#ganestimator_eval'>Eval</a><br>\n",
"<a href='#conditional_example'>Conditional GAN example</a><br>\n",
" <a href='#conditional_input'>Input pipeline</a><br>\n",
" <a href='#conditional_model'>Model</a><br>\n",
" <a href='#conditional_loss'>Loss</a><br>\n",
" <a href='#conditional_train'>Train and evaluation</a><br>\n",
"<a href='#infogan_example'>InfoGAN example</a><br>\n",
" <a href='#infogan_input'>Input pipeline</a><br>\n",
" <a href='#infogan_model'>Model</a><br>\n",
" <a href='#infogan_loss'>Loss</a><br>\n",
" <a href='#infogan_train'>Train and evaluation</a><br>"
]
},
{
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment