Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9b51944b
Unverified
Commit
9b51944b
authored
Dec 01, 2017
by
Neal Wu
Committed by
GitHub
Dec 01, 2017
Browse files
Merge pull request #2932 from joel-shor/master
Add image to image translation example
parents
0b2bc49f
a585fc16
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
433 additions
and
0 deletions
+433
-0
research/gan/pix2pix/launch_jobs.sh
research/gan/pix2pix/launch_jobs.sh
+74
-0
research/gan/pix2pix/networks.py
research/gan/pix2pix/networks.py
+52
-0
research/gan/pix2pix/networks_test.py
research/gan/pix2pix/networks_test.py
+76
-0
research/gan/pix2pix/train.py
research/gan/pix2pix/train.py
+178
-0
research/gan/pix2pix/train_test.py
research/gan/pix2pix/train_test.py
+53
-0
No files found.
research/gan/pix2pix/launch_jobs.sh
0 → 100755
View file @
9b51944b
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
# NOTE: a shebang is only honored when it is the very first line of the file.
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains a pix2pix image-to-image translation model on patches from
#    Imagenet.
#
# Usage:
# cd models/research/gan/pix2pix
# ./launch_jobs.sh ${weight_factor} ${git_repo}
set -e

# Weight of the adversarial loss.
weight_factor=$1
if [[ -z "${weight_factor}" ]]; then
  echo "'weight_factor' must not be empty." >&2
  # A bare `exit` would return the status of the `echo` above (0), which
  # would make a failed validation look like success to the caller.
  exit 1
fi

# Location of the git repository.
git_repo=$2
if [[ -z "${git_repo}" ]]; then
  echo "'git_repo' must not be empty." >&2
  exit 1
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model

# Base name for where the evaluation images will be saved to.
# (Defined for parity with the other GAN examples; not used below.)
EVAL_DIR=/tmp/compression-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data

export PYTHONPATH="${PYTHONPATH}:${git_repo}:${git_repo}/research:${git_repo}/research/slim:${git_repo}/research/slim/nets"

# A helper function for printing pretty output.
Banner () {
  local text=$1
  local green='\033[0;32m'
  local nc='\033[0m'  # No color.
  echo -e "${green}${text}${nc}"
}

# Download the dataset.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" "${DATASET_DIR}"

# Run the pix2pix model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training a pix2pix model for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/pix2pix/train.py" \
  --train_log_dir="${MODEL_TRAIN_DIR}" \
  --dataset_dir="${DATASET_DIR}" \
  --max_number_of_steps="${NUM_STEPS}" \
  --weight_factor="${weight_factor}" \
  --alsologtostderr
Banner "Finished training pix2pix model ${NUM_STEPS} steps."
research/gan/pix2pix/networks.py
0 → 100644
View file @
9b51944b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN Pix2Pix example using TFGAN."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
slim.nets
import
cyclegan
from
slim.nets
import
pix2pix
def generator(input_images):
  """Translates an image batch with the CycleGAN ResNet generator.

  Adapts `cyclegan.cyclegan_generator_resnet` to the one-tensor-in,
  one-tensor-out generator signature used with TFGAN.

  Args:
    input_images: A batch of images to translate. Images should be normalized
      already. Shape is [batch, height, width, channels].

  Returns:
    The generated image batch (the first of the two values returned by the
    CycleGAN generator; the second is discarded).

  Raises:
    ValueError: If `input_images` does not have rank 4.
  """
  input_images.shape.assert_has_rank(4)
  generator_arg_scope = cyclegan.cyclegan_arg_scope()
  with tf.contrib.framework.arg_scope(generator_arg_scope):
    translated, _ = cyclegan.cyclegan_generator_resnet(input_images)
    return translated
def discriminator(image_batch, unused_conditioning=None):
  """Scores an image batch with the Pix2Pix discriminator, TFGAN-style.

  Args:
    image_batch: The batch of images to score.
    unused_conditioning: Accepted for signature compatibility and ignored.

  Returns:
    The discriminator logits flattened to 2D.
  """
  discriminator_arg_scope = pix2pix.pix2pix_arg_scope()
  with tf.contrib.framework.arg_scope(discriminator_arg_scope):
    patch_logits, _ = pix2pix.pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
    patch_logits.shape.assert_has_rank(4)

  # The discriminator emits 4D logits; TFGAN wants them flattened to 2D.
  return tf.contrib.layers.flatten(patch_logits)
research/gan/pix2pix/networks_test.py
0 → 100644
View file @
9b51944b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pix2pix.networks."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
google3.third_party.tensorflow_models.gan.pix2pix
import
networks
class Pix2PixTest(tf.test.TestCase):
  """Exercises the generator and discriminator wrappers in `networks`."""

  def test_generator_run(self):
    # Build the generator on an all-zeros batch and execute it once.
    generated = networks.generator(tf.zeros([3, 128, 128, 3]))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(generated)

  def test_generator_graph(self):
    # The generator's static output shape must match its input shape.
    leading_dims = ([4, 32, 32], [3, 128, 128], [2, 80, 400])
    for dims in leading_dims:
      tf.reset_default_graph()
      generated = networks.generator(tf.ones(dims + [3]))
      self.assertAllEqual(dims + [3], generated.shape.as_list())

  def test_generator_graph_unknown_batch_dim(self):
    # An unknown batch dimension must be carried through to the output.
    images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    generated = networks.generator(images)
    self.assertAllEqual([None, 32, 32, 3], generated.shape.as_list())

  def test_generator_invalid_input(self):
    # A rank-3 tensor must be rejected.
    with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
      networks.generator(tf.zeros([28, 28, 3]))

  def test_discriminator_run(self):
    # Build the discriminator on an all-zeros batch and execute it once.
    logits = networks.discriminator(tf.zeros([3, 70, 70, 3]))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(logits)

  def test_discriminator_graph(self):
    # Check graph construction for a number of image size/depths and batch
    # sizes.
    for batch_size, patch_size in ((3, 70), (6, 128)):
      tf.reset_default_graph()
      images = tf.ones([batch_size, patch_size, patch_size, 3])
      logits = networks.discriminator(images)
      self.assertEqual(2, logits.shape.ndims)
      self.assertEqual(batch_size, logits.shape.as_list()[0])

  def test_discriminator_invalid_input(self):
    # A rank-3 tensor must be rejected.
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
      networks.discriminator(tf.zeros([28, 28, 3]))
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
research/gan/pix2pix/train.py
0 → 100644
View file @
9b51944b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains an image-to-image translation network with an adversarial loss."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
import
data_provider
from
google3.third_party.tensorflow_models.gan.pix2pix
import
networks
# Short aliases for the TensorFlow flags module and the TFGAN library.
flags = tf.flags
tfgan = tf.contrib.gan


# Command-line flags read by `main` through `FLAGS`.
flags.DEFINE_integer('batch_size', 10, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/pix2pix/',
                    'Directory where to write event logs.')

# The help text previously read "compression model", a leftover from the
# image-compression example; this rate is passed to the generator optimizer.
flags.DEFINE_float('generator_lr', 0.00001,
                   'The generator learning rate.')

flags.DEFINE_float('discriminator_lr', 0.00001,
                   'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 2000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_float(
    'weight_factor', 0.0,
    'How much to weight the adversarial loss relative to pixel loss.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')


FLAGS = flags.FLAGS
def main(_):
  """Builds the pix2pix training graph and runs the GAN training loop.

  Reads all configuration from `FLAGS`. Creates `FLAGS.train_log_dir` if it
  does not exist, builds the input pipeline, GAN model, losses and train ops,
  and then calls `tfgan.gan_train`. When `FLAGS.max_number_of_steps` is 0 the
  function returns after graph construction without training.

  Args:
    _: Unused positional argument.
  """
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Get real and distorted images.
    with tf.device('/cpu:0'), tf.name_scope('inputs'):
      real_images = data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
          patch_size=FLAGS.patch_size)
    # The generator input is the real batch resized down to half the patch
    # size and back up again (see `_distort_images`).
    distorted_images = _distort_images(
        real_images, downscale_size=int(FLAGS.patch_size / 2),
        upscale_size=FLAGS.patch_size)

    # Create a GANModel tuple.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=real_images,
        generator_inputs=distorted_images)
    tfgan.eval.add_image_comparison_summaries(
        gan_model, num_comparisons=3, display_diffs=True)
    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=3)

    # Define the GANLoss tuple using standard library functions.
    with tf.name_scope('losses'):
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.least_squares_generator_loss,
          discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

      # Define the standard L1 pixel loss. `**` binds tighter than `/`, so the
      # norm is divided by patch_size squared.
      l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                              ord=1) / FLAGS.patch_size ** 2

      # Modify the loss tuple to include the pixel loss. Add summaries as well.
      gan_loss = tfgan.losses.combine_adversarial_loss(
          gan_loss, gan_model, l1_pixel_loss,
          weight_factor=FLAGS.weight_factor)

    with tf.name_scope('train_ops'):
      # Get the GANTrain ops using the custom optimizers. Gradients are passed
      # through `clip_gradient_norms_fn(1e3)` before being applied.
      gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
      gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True,
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
          transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3))
      tf.summary.scalar('generator_lr', gen_lr)
      tf.summary.scalar('discriminator_lr', dis_lr)

    # Use GAN train step function if using adversarial loss, otherwise
    # only train the generator: the discriminator gets 1 step per iteration
    # when weight_factor > 0 and 0 steps otherwise.
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=1,
        discriminator_train_steps=int(FLAGS.weight_factor > 0))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if FLAGS.max_number_of_steps == 0: return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
               tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def _optimizer(gen_lr, dis_lr):
  """Creates the Adam optimizers for the generator and the discriminator.

  Both optimizers use beta1=0.5 and beta2=0.999 and differ only in their
  learning rate.

  Args:
    gen_lr: Learning rate for the generator optimizer.
    dis_lr: Learning rate for the discriminator optimizer.

  Returns:
    A `(generator_optimizer, discriminator_optimizer)` tuple.
  """
  def _adam(learning_rate):
    return tf.train.AdamOptimizer(learning_rate, beta1=0.5, beta2=0.999)

  # Tuple elements are evaluated left to right, so the generator optimizer is
  # still constructed first.
  return _adam(gen_lr), _adam(dis_lr)
def _lr(gen_lr_base, dis_lr_base):
  """Returns the generator and discriminator learning rates.

  The generator rate is decayed in a staircase by a factor of 0.8 every
  100000 global steps; the discriminator rate is returned unchanged.

  Args:
    gen_lr_base: Initial generator learning rate.
    dis_lr_base: Discriminator learning rate.

  Returns:
    A `(generator_lr, discriminator_lr)` tuple.
  """
  global_step = tf.train.get_or_create_global_step()
  decayed_gen_lr = tf.train.exponential_decay(
      learning_rate=gen_lr_base,
      global_step=global_step,
      decay_steps=100000,
      decay_rate=0.8,
      staircase=True)
  return decayed_gen_lr, dis_lr_base
def _distort_images(images, downscale_size, upscale_size):
  """Degrades images by area-resizing them down and then back up.

  Args:
    images: The image batch to distort.
    downscale_size: Side length of the square intermediate size.
    upscale_size: Side length of the square output size.

  Returns:
    The images resized to `downscale_size` and then to `upscale_size`.
  """
  small_size = [downscale_size, downscale_size]
  large_size = [upscale_size, upscale_size]
  return tf.image.resize_area(
      tf.image.resize_area(images, small_size), large_size)
# Script entry point: hand control to tf.app.run() when executed directly.
if __name__ == '__main__':
  tf.app.run()
research/gan/pix2pix/train_test.py
0 → 100644
View file @
9b51944b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pix2pix.train."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
from
google3.third_party.tensorflow_models.gan.pix2pix
import
train
# Shared flag values; the tests below overwrite them before calling train.main.
FLAGS = tf.flags.FLAGS
# Alias for the mock library exposed through tf.test.
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
  """Graph-construction tests for `train.main`."""

  def _test_build_graph_helper(self, weight_factor):
    # Zero steps makes `train.main` return before the training loop starts.
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    # Replace the real input pipeline with an all-zeros float32 batch.
    batch_shape = [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3]
    fake_batch = np.zeros(batch_shape, dtype=np.float32)
    provider_patch = mock.patch.object(train, 'data_provider')
    with provider_patch as fake_provider:
      fake_provider.provide_data.return_value = fake_batch
      train.main(None)

  def test_build_graph_noadversarialloss(self):
    # weight_factor of 0.0: the no-adversarial-loss configuration.
    self._test_build_graph_helper(0.0)

  def test_build_graph_adversarialloss(self):
    # weight_factor of 1.0: the adversarial-loss configuration.
    self._test_build_graph_helper(1.0)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment