ModelZoo / ResNet50_tensorflow · Commits · 54a5a577

Commit 54a5a577, authored Sep 29, 2018 by Joel Shor (committed by Joel Shor, Sep 29, 2018)

Project import generated by Copybara.

PiperOrigin-RevId: 215004158

Parent: 4f7074f6
Changes: 27. Showing 20 changed files with 1543 additions and 80 deletions (+1543, -80).
research/gan/cifar/train_test.py  (+1, -1)
research/gan/cyclegan/data_provider.py  (+61, -26)
research/gan/cyclegan/data_provider_test.py  (+41, -15)
research/gan/cyclegan/inference_demo_test.py  (+6, -1)
research/gan/cyclegan/train.py  (+4, -1)
research/gan/cyclegan/train_test.py  (+4, -3)
research/gan/progressive_gan/train.py  (+41, -16)
research/gan/progressive_gan/train_main.py  (+10, -6)
research/gan/progressive_gan/train_test.py  (+20, -11)
research/gan/stargan/data_provider.py  (+33, -0)
research/gan/stargan/data_provider_test.py  (+42, -0)
research/gan/stargan/layers.py  (+392, -0)
research/gan/stargan/layers_test.py  (+137, -0)
research/gan/stargan/network.py  (+102, -0)
research/gan/stargan/network_test.py  (+61, -0)
research/gan/stargan/ops.py  (+106, -0)
research/gan/stargan/ops_test.py  (+113, -0)
research/gan/stargan/train.py  (+228, -0)
research/gan/stargan/train_test.py  (+141, -0)
research/gan/stargan_estimator/data/celeba_test_split_images.npy  (+0, -0)
research/gan/cifar/train_test.py

@@ -35,7 +35,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
       ('Unconditional', False, False),
       ('Conditional', True, False),
       ('SyncReplicas', False, True))
-  def test_build_graph_helper(self, conditional, use_sync_replicas):
+  def test_build_graph(self, conditional, use_sync_replicas):
     FLAGS.max_number_of_steps = 0
     FLAGS.conditional = conditional
     FLAGS.use_sync_replicas = use_sync_replicas
research/gan/cyclegan/data_provider.py

@@ -80,36 +80,34 @@ def _provide_custom_dataset(image_file_pattern,
     image_file_pattern: A string of glob pattern of image files.
     batch_size: The number of images in each batch.
     shuffle: Whether to shuffle the read images. Defaults to True.
-    num_threads: Number of prefetching threads. Defaults to 1.
+    num_threads: Number of mapping threads. Defaults to 1.
     patch_size: Size of the path to extract from the image. Defaults to 128.

   Returns:
-    A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
-    representing a batch of images.
+    A tf.data.Dataset with Tensors of shape
+    [batch_size, patch_size, patch_size, 3] representing a batch of images.
+
+  Raises:
+    ValueError: If no files match `image_file_pattern`.
   """
-  filename_queue = tf.train.string_input_producer(
-      tf.train.match_filenames_once(image_file_pattern),
-      shuffle=shuffle,
-      capacity=5 * batch_size)
-  image_reader = tf.WholeFileReader()
-
-  _, image_bytes = image_reader.read(filename_queue)
-  image = tf.image.decode_image(image_bytes)
-  image_patch = full_image_to_patch(image, patch_size)
-
-  if shuffle:
-    return tf.train.shuffle_batch(
-        [image_patch],
-        batch_size=batch_size,
-        num_threads=num_threads,
-        capacity=5 * batch_size,
-        min_after_dequeue=batch_size)
-  else:
-    return tf.train.batch(
-        [image_patch],
-        batch_size=batch_size,
-        num_threads=1,  # no threads so it's deterministic
-        capacity=5 * batch_size)
+  if not tf.gfile.Glob(image_file_pattern):
+    raise ValueError('No file patterns found.')
+
+  filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
+  bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
+  images_ds = bytes_ds.map(tf.image.decode_image, num_parallel_calls=num_threads)
+  patches_ds = images_ds.map(
+      lambda img: full_image_to_patch(img, patch_size),
+      num_parallel_calls=num_threads)
+  patches_ds = patches_ds.repeat()
+
+  if shuffle:
+    patches_ds = patches_ds.shuffle(5 * batch_size)
+  patches_ds = patches_ds.prefetch(5 * batch_size)
+  patches_ds = patches_ds.batch(batch_size)
+
+  return patches_ds


 def provide_custom_datasets(image_file_patterns,

@@ -127,8 +125,8 @@ def provide_custom_datasets(image_file_patterns,
     patch_size: Size of the patch to extract from the image. Defaults to 128.

   Returns:
-    A list of float `Tensor`s with the same size of `image_file_patterns`. Each
-    of the `Tensor` in the list has a shape of
+    A list of tf.data.Datasets the same number as `image_file_patterns`. Each
+    of the datasets have `Tensor`'s in the list has a shape of
     [batch_size, patch_size, patch_size, 3] representing a batch of images.

   Raises:

@@ -147,4 +145,41 @@ def provide_custom_datasets(image_file_patterns,
             shuffle=shuffle,
             num_threads=num_threads,
             patch_size=patch_size))

   return custom_datasets
+
+
+def provide_custom_data(image_file_patterns,
+                        batch_size,
+                        shuffle=True,
+                        num_threads=1,
+                        patch_size=128):
+  """Provides multiple batches of custom image data.
+
+  Args:
+    image_file_patterns: A list of glob patterns of image files.
+    batch_size: The number of images in each batch.
+    shuffle: Whether to shuffle the read images. Defaults to True.
+    num_threads: Number of prefetching threads. Defaults to 1.
+    patch_size: Size of the patch to extract from the image. Defaults to 128.
+
+  Returns:
+    A list of float `Tensor`s with the same size of `image_file_patterns`. Each
+    of the `Tensor` in the list has a shape of
+    [batch_size, patch_size, patch_size, 3] representing a batch of images. As a
+    side effect, the tf.Dataset initializer is added to the
+    tf.GraphKeys.TABLE_INITIALIZERS collection.
+
+  Raises:
+    ValueError: If image_file_patterns is not a list or tuple.
+  """
+  datasets = provide_custom_datasets(
+      image_file_patterns, batch_size, shuffle, num_threads, patch_size)
+
+  tensors = []
+  for ds in datasets:
+    iterator = ds.make_initializable_iterator()
+    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
+    tensors.append(iterator.get_next())
+
+  return tensors
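
For context, a minimal sketch (not part of this commit) of how the new provide_custom_data wrapper is meant to be consumed: because each dataset's iterator initializer is registered in the tf.GraphKeys.TABLE_INITIALIZERS collection as a side effect, a caller only needs to run tf.tables_initializer() before pulling batches. The file patterns below are hypothetical, and the TF 1.x session API is assumed.

import tensorflow as tf

import data_provider  # the cyclegan module changed above

# Hypothetical globs; any two directories of JPEGs would do.
images_x, images_y = data_provider.provide_custom_data(
    ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  # Runs the iterator initializers registered in TABLE_INITIALIZERS.
  sess.run(tf.tables_initializer())
  batch_x, batch_y = sess.run([images_x, images_y])
  print(batch_x.shape)  # (3, 8, 8, 3)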
research/gan/cyclegan/data_provider_test.py

@@ -62,41 +62,67 @@ class DataProviderTest(tf.test.TestCase):
     file_pattern = os.path.join(self.testdata_dir, '*.jpg')
     batch_size = 3
     patch_size = 8
-    images = data_provider._provide_custom_dataset(
+    images_ds = data_provider._provide_custom_dataset(
         file_pattern, batch_size=batch_size, patch_size=patch_size)
-    self.assertListEqual([batch_size, patch_size, patch_size, 3],
-                         images.shape.as_list())
-    self.assertEqual(tf.float32, images.dtype)
+    self.assertListEqual([None, patch_size, patch_size, 3],
+                         images_ds.output_shapes.as_list())
+    self.assertEqual(tf.float32, images_ds.output_types)
+    iterator = images_ds.make_initializable_iterator()

     with self.test_session(use_gpu=True) as sess:
       sess.run(tf.local_variables_initializer())
-      with tf.contrib.slim.queues.QueueRunners(sess):
-        images_out = sess.run(images)
+      sess.run(iterator.initializer)
+      images_out = sess.run(iterator.get_next())
       self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
                             images_out.shape)
       self.assertTrue(np.all(np.abs(images_out) <= 1.0))

+  def test_custom_datasets_provider(self):
+    file_pattern = os.path.join(self.testdata_dir, '*.jpg')
+    batch_size = 3
+    patch_size = 8
+    images_ds_list = data_provider.provide_custom_datasets(
+        [file_pattern, file_pattern],
+        batch_size=batch_size,
+        patch_size=patch_size)
+    for images_ds in images_ds_list:
+      self.assertListEqual([None, patch_size, patch_size, 3],
+                           images_ds.output_shapes.as_list())
+      self.assertEqual(tf.float32, images_ds.output_types)
+
+    iterators = [x.make_initializable_iterator() for x in images_ds_list]
+    initialiers = [x.initializer for x in iterators]
+    img_tensors = [x.get_next() for x in iterators]
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(tf.local_variables_initializer())
+      sess.run(initialiers)
+      images_out_list = sess.run(img_tensors)
+      for images_out in images_out_list:
+        self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
+                              images_out.shape)
+        self.assertTrue(np.all(np.abs(images_out) <= 1.0))

-  def test_custom_datasets_provider(self):
+  def test_custom_data_provider(self):
     file_pattern = os.path.join(self.testdata_dir, '*.jpg')
     batch_size = 3
     patch_size = 8
-    images_list = data_provider.provide_custom_datasets(
+    images_list = data_provider.provide_custom_data(
         [file_pattern, file_pattern],
         batch_size=batch_size,
         patch_size=patch_size)
     for images in images_list:
-      self.assertListEqual([batch_size, patch_size, patch_size, 3],
+      self.assertListEqual([None, patch_size, patch_size, 3],
                            images.shape.as_list())
       self.assertEqual(tf.float32, images.dtype)

     with self.test_session(use_gpu=True) as sess:
       sess.run(tf.local_variables_initializer())
-      with tf.contrib.slim.queues.QueueRunners(sess):
-        images_out_list = sess.run(images_list)
-        for images_out in images_out_list:
-          self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
-                                images_out.shape)
-          self.assertTrue(np.all(np.abs(images_out) <= 1.0))
+      sess.run(tf.tables_initializer())
+      images_out_list = sess.run(images_list)
+      for images_out in images_out_list:
+        self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
+                              images_out.shape)
+        self.assertTrue(np.all(np.abs(images_out) <= 1.0))


 if __name__ == '__main__':
research/gan/cyclegan/inference_demo_test.py

@@ -35,7 +35,10 @@ class InferenceDemoTest(tf.test.TestCase):
     self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')

   @mock.patch.object(tfgan, 'gan_train', autospec=True)
-  def testTrainingAndInferenceGraphsAreCompatible(self, unused_mock_gan_train):
+  @mock.patch.object(train.data_provider, 'provide_custom_data', autospec=True)
+  def testTrainingAndInferenceGraphsAreCompatible(
+      self, mock_provide_custom_data, unused_mock_gan_train):
     # Training and inference graphs can get out of sync if changes are made
     # to one but not the other. This test will keep them in sync.

@@ -52,6 +55,8 @@ class InferenceDemoTest(tf.test.TestCase):
     FLAGS.task = 0
     FLAGS.cycle_consistency_loss_weight = 2.0
     FLAGS.max_number_of_steps = 1
+    mock_provide_custom_data.return_value = (
+        tf.zeros([3, 4, 4, 3,]), tf.zeros([3, 4, 4, 3]))

     train.main(None)

     init_op = tf.global_variables_initializer()
     train_sess.run(init_op)
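
A note on the decorator stacking above, since the argument order is easy to misread: stacked mock.patch.object decorators are applied bottom-up, so the decorator nearest the function supplies the first mock argument. That is why mock_provide_custom_data precedes unused_mock_gan_train in the new signature. A standalone sketch, with a hypothetical class and unittest.mock standing in for the test's mock module:

from unittest import mock


class Api(object):

  def gan_train(self):
    pass

  def provide_custom_data(self):
    pass


@mock.patch.object(Api, 'gan_train', autospec=True)            # outer: last arg
@mock.patch.object(Api, 'provide_custom_data', autospec=True)  # inner: first arg
def demo(mock_provide, mock_train):
  # The inner (bottom) patch arrives first.
  print(mock_provide)  # mock named 'provide_custom_data'
  print(mock_train)    # mock named 'gan_train'


demo()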
research/gan/cyclegan/train.py

@@ -169,10 +169,13 @@ def main(_):
   with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
     with tf.name_scope('inputs'):
-      images_x, images_y = data_provider.provide_custom_datasets(
+      images_x, images_y = data_provider.provide_custom_data(
           [FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
           batch_size=FLAGS.batch_size,
           patch_size=FLAGS.patch_size)
+      # Set batch size for summaries.
+      images_x.set_shape([FLAGS.batch_size, None, None, None])
+      images_y.set_shape([FLAGS.batch_size, None, None, None])

     # Define CycleGAN model.
     cyclegan_model = _define_model(images_x, images_y)
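
The two set_shape calls are the companion fix for the new provider: provide_custom_data yields tensors whose batch dimension is unknown (None) because they come from a tf.data iterator, while the summary code downstream wants a static batch size. set_shape only refines static shape information; it does not reshape anything at run time. A minimal illustration (assuming TF 1.x):

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, None, None, 3])
print(images.shape.as_list())  # [None, None, None, 3]

# Pin just the batch dimension; height/width/channels stay dynamic.
images.set_shape([3, None, None, None])
print(images.shape.as_list())  # [3, None, None, None]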
research/gan/cyclegan/train_test.py

@@ -128,11 +128,12 @@ class TrainTest(tf.test.TestCase):
     FLAGS.cycle_consistency_loss_weight = 2.0
     FLAGS.max_number_of_steps = 1
-    mock_data_provider.provide_custom_datasets.return_value = (
-        tf.zeros([1, 2], dtype=tf.float32), tf.zeros([1, 2], dtype=tf.float32))
+    mock_data_provider.provide_custom_data.return_value = (
+        tf.zeros([3, 2, 2, 3], dtype=tf.float32),
+        tf.zeros([3, 2, 2, 3], dtype=tf.float32))

     train.main(None)
-    mock_data_provider.provide_custom_datasets.assert_called_once_with(
+    mock_data_provider.provide_custom_data.assert_called_once_with(
         ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
     mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
     mock_cyclegan_loss.assert_called_once_with(
research/gan/progressive_gan/train.py

@@ -75,6 +75,34 @@ def get_total_num_stages(**kwargs):
   return 2 * kwargs['num_resolutions'] - 1


+def get_batch_size(stage_id, **kwargs):
+  """Returns batch size for each stage.
+
+  It is expected that `len(batch_size_schedule) == num_resolutions`. Each stage
+  corresponds to a resolution and hence a batch size. However if
+  `len(batch_size_schedule) < num_resolutions`, pad `batch_size_schedule` in the
+  beginning with the first batch size.
+
+  Args:
+    stage_id: An integer of training stage index.
+    **kwargs: A dictionary of
+        'batch_size_schedule': A list of integer, each element is the batch size
+            for the current training image resolution.
+        'num_resolutions': An integer of number of progressive resolutions.
+
+  Returns:
+    An integer batch size for the `stage_id`.
+  """
+  batch_size_schedule = kwargs['batch_size_schedule']
+  num_resolutions = kwargs['num_resolutions']
+  if len(batch_size_schedule) < num_resolutions:
+    batch_size_schedule = (
+        [batch_size_schedule[0]] * (num_resolutions - len(batch_size_schedule))
+        + batch_size_schedule)
+
+  return int(batch_size_schedule[(stage_id + 1) // 2])
+
+
 def get_stage_info(stage_id, **kwargs):
   """Returns information for a training stage.

@@ -228,14 +256,14 @@ def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
   return gan_train_ops, generator_vars_to_restore


-def build_model(stage_id, real_images, **kwargs):
+def build_model(stage_id, batch_size, real_images, **kwargs):
   """Builds progressive GAN model.

   Args:
     stage_id: An integer of training stage index.
+    batch_size: Number of training images in each minibatch.
     real_images: A 4D `Tensor` of NHWC format.
     **kwargs: A dictionary of
-        'batch_size': Number of training images in each minibatch.
         'start_height': An integer of start image height.
         'start_width': An integer of start image width.
         'scale_base': An integer of resolution multiplier.

@@ -267,15 +295,14 @@ def build_model(stage_id, real_images, **kwargs):
   Returns:
     An inernal object that wraps all information about the model.
   """
-  batch_size = kwargs['batch_size']
   kernel_size = kwargs['kernel_size']
   colors = kwargs['colors']
   resolution_schedule = make_resolution_schedule(**kwargs)

   num_blocks, num_images = get_stage_info(stage_id, **kwargs)

-  global_step = tf.train.get_or_create_global_step()
-  current_image_id = global_step * batch_size
+  current_image_id = tf.train.get_or_create_global_step()
+  current_image_id_inc_op = current_image_id.assign_add(batch_size)
   tf.summary.scalar('current_image_id', current_image_id)

   progress = networks.compute_progress(

@@ -329,6 +356,8 @@ def build_model(stage_id, real_images, **kwargs):
   ########## Define train ops.
   gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
                                                        **kwargs)
+  gan_train_ops = gan_train_ops._replace(
+      global_step_inc_op=current_image_id_inc_op)

   ########## Generator smoothing.
   generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)

@@ -339,11 +368,11 @@ def build_model(stage_id, real_images, **kwargs):
     pass
   model = Model()
-  model.resolution_schedule = resolution_schedule
   model.stage_id = stage_id
+  model.batch_size = batch_size
+  model.resolution_schedule = resolution_schedule
   model.num_images = num_images
   model.num_blocks = num_blocks
-  model.global_step = global_step
   model.current_image_id = current_image_id
   model.progress = progress
   model.num_filters_fn = _num_filters_fn

@@ -380,7 +409,6 @@ def add_model_summaries(model, **kwargs):
     model: An model object having all information of progressive GAN model,
       e.g. the return of build_model().
     **kwargs: A dictionary of
-        'batch_size': Number of training images in each minibatch.
         'fake_grid_size': The fake image grid size for summaries.
         'interp_grid_size': The latent space interpolated image grid size for
             summaries.

@@ -431,7 +459,7 @@ def add_model_summaries(model, **kwargs):
           num_channels=colors),
       max_outputs=1)
-  real_grid_size = int(np.sqrt(kwargs['batch_size']))
+  real_grid_size = int(np.sqrt(model.batch_size))
   tf.summary.image(
       'real_images_blend',
       tfgan.eval.eval_utils.image_grid(

@@ -517,11 +545,10 @@ def make_status_message(model):
   """Makes a string `Tensor` of training status."""
   return tf.string_join(
       [
-          'Starting train step: ',
-          tf.as_string(model.global_step),
-          ', current_image_id: ',
+          'Starting train step: current_image_id: ',
           tf.as_string(model.current_image_id),
           ', progress: ',
           tf.as_string(model.progress),
-          ', num_blocks: {}'.format(model.num_blocks)
+          ', num_blocks: {}'.format(model.num_blocks),
+          ', batch_size: {}'.format(model.batch_size)
       ],
       name='status_message')

@@ -541,8 +568,6 @@ def train(model, **kwargs):
   Returns:
     None.
   """
-  batch_size = kwargs['batch_size']
   logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
                model.num_blocks, model.num_images)

@@ -553,7 +578,7 @@ def train(model, **kwargs):
       logdir=make_train_sub_dir(model.stage_id, **kwargs),
       get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
       hooks=[
-          tf.train.StopAtStepHook(last_step=model.num_images // batch_size),
+          tf.train.StopAtStepHook(last_step=model.num_images),
           tf.train.LoggingTensorHook(
               [make_status_message(model)], every_n_iter=10)
       ],

@@ -561,4 +586,4 @@ def train(model, **kwargs):
       is_chief=(kwargs['task'] == 0),
       scaffold=scaffold,
       save_checkpoint_secs=600,
-      save_summaries_steps=(kwargs['save_summaries_num_images'] // batch_size))
+      save_summaries_steps=(kwargs['save_summaries_num_images']))
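
To make the new get_batch_size indexing concrete: stages alternate stable/transition, two per resolution after the first, so (stage_id + 1) // 2 maps a stage index onto a resolution index. A worked check follows; the logic restates the function added above for the example only, and the values match the new unit test in progressive_gan/train_test.py:

def get_batch_size(stage_id, **kwargs):
  # Restated from the commit, inlined for the example.
  schedule = kwargs['batch_size_schedule']
  num_resolutions = kwargs['num_resolutions']
  if len(schedule) < num_resolutions:
    schedule = [schedule[0]] * (num_resolutions - len(schedule)) + schedule
  return int(schedule[(stage_id + 1) // 2])


config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
# The schedule pads to [8, 8, 8, 4, 2]; stage ids 0..8 map to resolution
# indices 0, 1, 1, 2, 2, 3, 3, 4, 4.
print([get_batch_size(i, **config) for i in range(9)])
# [8, 8, 8, 8, 8, 4, 4, 2, 2]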
research/gan/progressive_gan/train_main.py

@@ -48,6 +48,12 @@ flags.DEFINE_integer('scale_base', 2, 'Resolution multiplier.')
 flags.DEFINE_integer('num_resolutions', 4, 'Number of progressive resolutions.')

+flags.DEFINE_list(
+    'batch_size_schedule', [8, 8, 4],
+    'A list of batch sizes for each resolution, if '
+    'len(batch_size_schedule) < num_resolutions, pad the schedule in the '
+    'beginning with the first batch size.')
+
 flags.DEFINE_integer('kernel_size', 3, 'Convolution kernel size.')

 flags.DEFINE_integer('colors', 3, 'Number of image channels.')

@@ -55,8 +61,6 @@ flags.DEFINE_integer('colors', 3, 'Number of image channels.')
 flags.DEFINE_bool('to_rgb_use_tanh_activation', False,
                   'Whether to apply tanh activation when output rgb.')

-flags.DEFINE_integer('batch_size', 8, 'Number of images in each batch.')
-
 flags.DEFINE_integer('stable_stage_num_images', 1000,
                      'Number of images in the stable stage.')

@@ -123,11 +127,10 @@ def _make_config_from_flags():
        for flag in FLAGS.get_key_flags_for_module(sys.argv[0])])


-def _provide_real_images(**kwargs):
+def _provide_real_images(batch_size, **kwargs):
   """Provides real images."""
   dataset_name = kwargs.get('dataset_name')
   dataset_file_pattern = kwargs.get('dataset_file_pattern')
-  batch_size = kwargs['batch_size']
   colors = kwargs['colors']
   final_height, final_width = train.make_resolution_schedule(
       **kwargs).final_resolutions

@@ -156,12 +159,13 @@ def main(_):
   logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.iteritems()]))

   for stage_id in train.get_stage_ids(**config):
+    batch_size = train.get_batch_size(stage_id, **config)
     tf.reset_default_graph()
     with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
       real_images = None
       with tf.device('/cpu:0'), tf.name_scope('inputs'):
-        real_images = _provide_real_images(**config)
+        real_images = _provide_real_images(batch_size, **config)
-      model = train.build_model(stage_id, real_images, **config)
+      model = train.build_model(stage_id, batch_size, real_images, **config)
       train.add_model_summaries(model, **config)
       train.train(model, **config)
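
One subtlety with the new batch_size_schedule flag: when overridden on the command line, flags.DEFINE_list yields a list of strings, not integers, which is presumably why get_batch_size casts the selected entry with int(...). A small sketch of the parsing behavior (assuming absl-style flags, as this module uses):

from absl import flags

flags.DEFINE_list('batch_size_schedule', [8, 8, 4],
                  'Batch size per resolution.')
FLAGS = flags.FLAGS

FLAGS(['prog', '--batch_size_schedule=16,8,4'])
print(FLAGS.batch_size_schedule)          # ['16', '8', '4'] -- strings
print(int(FLAGS.batch_size_schedule[0]))  # 16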
research/gan/progressive_gan/train_test.py

@@ -29,7 +29,7 @@ import train
 FLAGS = flags.FLAGS


-def provide_random_data(batch_size=2, patch_size=8, colors=1, **unused_kwargs):
+def provide_random_data(batch_size=2, patch_size=4, colors=1, **unused_kwargs):
   return tf.random_normal([batch_size, patch_size, patch_size, colors])

@@ -37,19 +37,19 @@ class TrainTest(absltest.TestCase):
   def setUp(self):
     self._config = {
-        'start_height': 4,
-        'start_width': 4,
+        'start_height': 2,
+        'start_width': 2,
         'scale_base': 2,
         'num_resolutions': 2,
+        'batch_size_schedule': [2],
         'colors': 1,
         'to_rgb_use_tanh_activation': True,
         'kernel_size': 3,
-        'batch_size': 2,
-        'stable_stage_num_images': 1,
-        'transition_stage_num_images': 1,
-        'total_num_images': 3,
-        'save_summaries_num_images': 2,
-        'latent_vector_size': 2,
+        'stable_stage_num_images': 4,
+        'transition_stage_num_images': 4,
+        'total_num_images': 12,
+        'save_summaries_num_images': 4,
+        'latent_vector_size': 8,
         'fmap_base': 8,
         'fmap_decay': 1.0,
         'fmap_max': 8,

@@ -73,12 +73,21 @@ class TrainTest(absltest.TestCase):
     tf.gfile.MakeDirs(train_root_dir)
     for stage_id in train.get_stage_ids(**self._config):
+      batch_size = train.get_batch_size(stage_id, **self._config)
       tf.reset_default_graph()
-      real_images = provide_random_data()
+      real_images = provide_random_data(batch_size=batch_size)
-      model = train.build_model(stage_id, real_images, **self._config)
+      model = train.build_model(stage_id, batch_size, real_images,
+                                **self._config)
       train.add_model_summaries(model, **self._config)
       train.train(model, **self._config)

+  def test_get_batch_size(self):
+    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
+    # batch_size_schedule is expanded to [8, 8, 8, 4, 2]
+    # At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
+    for i, expected_batch_size in enumerate([8, 8, 8, 8, 8, 4, 4, 2, 2]):
+      self.assertEqual(train.get_batch_size(i, **config), expected_batch_size)


 if __name__ == '__main__':
   absltest.main()
research/gan/stargan/data_provider.py (new file, 0 → 100644)

"""StarGAN data provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import data_provider


def provide_data(image_file_patterns, batch_size, patch_size):
  """Data provider wrapper on for the data_provider in gan/cyclegan.

  Args:
    image_file_patterns: A list of file pattern globs.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    List of `Tensor` of shape (N, H, W, C) representing the images.
    List of `Tensor` of shape (N, num_domains) representing the labels.
  """
  images = data_provider.provide_custom_data(
      image_file_patterns,
      batch_size=batch_size,
      patch_size=patch_size)

  num_domains = len(images)
  labels = [tf.one_hot([idx] * batch_size, num_domains)
            for idx in range(num_domains)]

  return images, labels
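
For intuition, the label construction above assigns one fixed domain per file pattern: every image drawn from pattern idx is paired with the same one-hot row. A minimal shape check (hypothetical sizes, TF 1.x session assumed):

import tensorflow as tf

batch_size, num_domains = 2, 3
labels = [tf.one_hot([idx] * batch_size, num_domains)
          for idx in range(num_domains)]

with tf.Session() as sess:
  out = sess.run(labels)

print([l.shape for l in out])  # [(2, 3), (2, 3), (2, 3)]
print(out[1])                  # every row is [0., 1., 0.] for domain 1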
research/gan/stargan/data_provider_test.py (new file, 0 → 100644)

"""Tests for google3.third_party.tensorflow_models.gan.stargan.data_provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from google3.testing.pybase import googletest
import data_provider

mock = tf.test.mock


class DataProviderTest(googletest.TestCase):

  @mock.patch.object(
      data_provider.data_provider, 'provide_custom_data', autospec=True)
  def test_data_provider(self, mock_provide_custom_data):

    batch_size = 2
    patch_size = 8
    num_domains = 3

    images_shape = [batch_size, patch_size, patch_size, 3]
    mock_provide_custom_data.return_value = [
        tf.zeros(images_shape) for _ in range(num_domains)
    ]

    images, labels = data_provider.provide_data(
        image_file_patterns=None, batch_size=batch_size, patch_size=patch_size)

    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    for label in labels:
      self.assertListEqual([batch_size, num_domains], label.shape.as_list())
    for image in images:
      self.assertListEqual(images_shape, image.shape.as_list())


if __name__ == '__main__':
  googletest.main()
research/gan/stargan/layers.py (new file, 0 → 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a StarGAN model.

This module contains basic layers to build a StarGAN model.

See https://arxiv.org/abs/1711.09020 for details about the model.

See https://github.com/yunjey/StarGAN for the original pytorvh implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import ops


def generator_down_sample(input_net, final_num_outputs=256):
  """Down-sampling module in Generator.

  Down sampling pathway of the Generator Architecture:

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L32

  Notes:
    We require dimension 1 and dimension 2 of the input_net to be fully defined
    for the correct down sampling.

  Args:
    input_net: Tensor of shape (batch_size, h, w, c + num_class).
    final_num_outputs: (int) Number of hidden unit for the final layer.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, 256).

  Raises:
    ValueError: If final_num_outputs are not divisible by 4,
      or input_net does not have a rank of 4,
      or dimension 1 and dimension 2 of input_net are not defined at graph
      construction time,
      or dimension 1 and dimension 2 of input_net are not divisible by 4.
  """
  if final_num_outputs % 4 != 0:
    raise ValueError('Final number outputs need to be divisible by 4.')

  # Check the rank of input_net.
  input_net.shape.assert_has_rank(4)

  # Check dimension 1 and dimension 2 are defined and divisible by 4.
  if input_net.shape[1]:
    if input_net.shape[1] % 4 != 0:
      raise ValueError(
          'Dimension 1 of the input should be divisible by 4, but is {} '
          'instead.'.format(input_net.shape[1]))
  else:
    raise ValueError('Dimension 1 of the input should be explicitly defined.')

  # Check dimension 1 and dimension 2 are defined and divisible by 4.
  if input_net.shape[2]:
    if input_net.shape[2] % 4 != 0:
      raise ValueError(
          'Dimension 2 of the input should be divisible by 4, but is {} '
          'instead.'.format(input_net.shape[2]))
  else:
    raise ValueError('Dimension 2 of the input should be explicitly defined.')

  with tf.variable_scope('generator_down_sample'):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d],
        padding='VALID',
        biases_initializer=None,
        normalizer_fn=tf.contrib.layers.instance_norm,
        activation_fn=tf.nn.relu):

      down_sample = ops.pad(input_net, 3)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs / 4,
          kernel_size=7,
          stride=1,
          scope='conv_0')

      down_sample = ops.pad(down_sample, 1)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs / 2,
          kernel_size=4,
          stride=2,
          scope='conv_1')

      down_sample = ops.pad(down_sample, 1)
      output_net = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs,
          kernel_size=4,
          stride=2,
          scope='conv_2')

  return output_net


def _residual_block(input_net,
                    num_outputs,
                    kernel_size,
                    stride=1,
                    padding_size=0,
                    activation_fn=tf.nn.relu,
                    normalizer_fn=None,
                    name='residual_block'):
  """Residual Block.

  Input Tensor X - > Conv1 -> IN -> ReLU -> Conv2 -> IN + X

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L7

  Args:
    input_net: Tensor as input.
    num_outputs: (int) number of output channels for Convolution.
    kernel_size: (int) size of the square kernel for Convolution.
    stride: (int) stride for Convolution. Default to 1.
    padding_size: (int) padding size for Convolution. Default to 0.
    activation_fn: Activation function.
    normalizer_fn: Normalization function.
    name: Name scope

  Returns:
    Residual Tensor with the same shape as the input tensor.
  """
  with tf.variable_scope(name):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d],
        num_outputs=num_outputs,
        kernel_size=kernel_size,
        stride=stride,
        padding='VALID',
        normalizer_fn=normalizer_fn,
        activation_fn=None):

      res_block = ops.pad(input_net, padding_size)
      res_block = tf.contrib.layers.conv2d(inputs=res_block, scope='conv_0')

      res_block = activation_fn(res_block, name='activation_0')

      res_block = ops.pad(res_block, padding_size)
      res_block = tf.contrib.layers.conv2d(inputs=res_block, scope='conv_1')

      output_net = res_block + input_net

  return output_net


def generator_bottleneck(input_net, residual_block_num=6, num_outputs=256):
  """Bottleneck module in Generator.

  Residual bottleneck pathway in Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L40

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, 256).
    residual_block_num: (int) Number of residual_block_num. Default to 6 per the
      original implementation.
    num_outputs: (int) Number of hidden unit in the residual bottleneck. Default
      to 256 per the original implementation.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, 256).

  Raises:
    ValueError: If the rank of the input tensor is not 4,
      or the last channel of the input_tensor is not explicitly defined,
      or the last channel of the input_tensor is not the same as num_outputs.
  """
  # Check the rank of input_net.
  input_net.shape.assert_has_rank(4)

  # Check dimension 4 of the input_net.
  if input_net.shape[-1]:
    if input_net.shape[-1] != num_outputs:
      raise ValueError(
          'The last dimension of the input_net should be the same as '
          'num_outputs: but {} vs. {} instead.'.format(input_net.shape[-1],
                                                       num_outputs))
  else:
    raise ValueError(
        'The last dimension of the input_net should be explicitly defined.')

  with tf.variable_scope('generator_bottleneck'):
    bottleneck = input_net
    for i in range(residual_block_num):
      bottleneck = _residual_block(
          input_net=bottleneck,
          num_outputs=num_outputs,
          kernel_size=3,
          stride=1,
          padding_size=1,
          activation_fn=tf.nn.relu,
          normalizer_fn=tf.contrib.layers.instance_norm,
          name='residual_block_{}'.format(i))

  return bottleneck


def generator_up_sample(input_net, num_outputs):
  """Up-sampling module in Generator.

  Up sampling path for image generation in the Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L44

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, 256).
    num_outputs: (int) Number of channel for the output tensor.

  Returns:
    Tensor of shape (batch_size, h, w, num_outputs).
  """
  with tf.variable_scope('generator_up_sample'):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d_transpose],
        kernel_size=4,
        stride=2,
        padding='VALID',
        normalizer_fn=tf.contrib.layers.instance_norm,
        activation_fn=tf.nn.relu):

      up_sample = tf.contrib.layers.conv2d_transpose(
          inputs=input_net, num_outputs=128, scope='deconv_0')
      up_sample = up_sample[:, 1:-1, 1:-1, :]

      up_sample = tf.contrib.layers.conv2d_transpose(
          inputs=up_sample, num_outputs=64, scope='deconv_1')
      up_sample = up_sample[:, 1:-1, 1:-1, :]

    output_net = ops.pad(up_sample, 3)
    output_net = tf.contrib.layers.conv2d(
        inputs=output_net,
        num_outputs=num_outputs,
        kernel_size=7,
        stride=1,
        padding='VALID',
        activation_fn=tf.nn.tanh,
        normalizer_fn=None,
        biases_initializer=None,
        scope='conv_0')

  return output_net


def discriminator_input_hidden(input_net, hidden_layer=6, init_num_outputs=64):
  """Input Layer + Hidden Layer in the Discriminator.

  Feature extraction pathway in the Discriminator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L68

  Args:
    input_net: Tensor of shape (batch_size, h, w, 3) as batch of images.
    hidden_layer: (int) Number of hidden layers. Default to 6 per the original
      implementation.
    init_num_outputs: (int) Number of hidden unit in the first hidden layer. The
      number of hidden unit double after each layer. Default to 64 per the
      original implementation.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 2048) as features.
  """
  num_outputs = init_num_outputs

  with tf.variable_scope('discriminator_input_hidden'):
    hidden = input_net
    for i in range(hidden_layer):
      hidden = ops.pad(hidden, 1)
      hidden = tf.contrib.layers.conv2d(
          inputs=hidden,
          num_outputs=num_outputs,
          kernel_size=4,
          stride=2,
          padding='VALID',
          activation_fn=None,
          normalizer_fn=None,
          scope='conv_{}'.format(i))
      hidden = tf.nn.leaky_relu(hidden, alpha=0.01)
      num_outputs = 2 * num_outputs

  return hidden


def discriminator_output_source(input_net):
  """Output Layer for Source in the Discriminator.

  Determine if the image is real/fake based on the feature extracted. We follow
  the original paper design where the output is not a simple (batch_size) shape
  Tensor but rather a (batch_size, 2, 2, 2048) shape Tensor. We will get the
  correct shape later when we piece things together.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L79

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2048) as features.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 1) as the score.
  """
  with tf.variable_scope('discriminator_output_source'):
    output_src = ops.pad(input_net, 1)
    output_src = tf.contrib.layers.conv2d(
        inputs=output_src,
        num_outputs=1,
        kernel_size=3,
        stride=1,
        padding='VALID',
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='conv')

  return output_src


def discriminator_output_class(input_net, class_num):
  """Output Layer for Domain Classification in the Discriminator.

  The original paper use convolution layer where the kernel size is the height
  and width of the Tensor. We use an equivalent operation here where we first
  flatten the Tensor to shape (batch_size, K) and a fully connected layer.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L80https

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2028).
    class_num: Number of output classes to be predicted.

  Returns:
    Tensor of shape (batch_size, class_num).
  """
  with tf.variable_scope('discriminator_output_class'):
    output_cls = tf.contrib.layers.flatten(input_net, scope='flatten')
    output_cls = tf.contrib.layers.fully_connected(
        inputs=output_cls,
        num_outputs=class_num,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='fully_connected')

  return output_cls
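
Taken together, the three generator modules round-trip the spatial size: down-sampling divides height and width by 4 while widening to 256 channels, the bottleneck is shape-preserving, and up-sampling multiplies by 4 again. A sketch (not part of the commit) of the shape flow for a 128x128 input with 3 + 3 conditioned channels, matching the tests below:

import tensorflow as tf

import layers  # the module added above

inputs = tf.random_uniform((2, 128, 128, 6))         # (n, h, w, c + num_domains)
down = layers.generator_down_sample(inputs)          # -> (2, 32, 32, 256)
bottleneck = layers.generator_bottleneck(down)       # -> (2, 32, 32, 256)
outputs = layers.generator_up_sample(bottleneck, 3)  # -> (2, 128, 128, 3)
print(outputs.shape.as_list())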
research/gan/stargan/layers_test.py (new file, 0 → 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

import layers


class LayersTest(tf.test.TestCase):

  def test_residual_block(self):
    n = 2
    h = 32
    w = h
    c = 256

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers._residual_block(
        input_net=input_tensor,
        num_outputs=c,
        kernel_size=3,
        stride=1,
        padding_size=1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_generator_down_sample(self):
    n = 2
    h = 128
    w = h
    c = 3 + 3

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_down_sample(input_tensor)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h // 4, w // 4, 256), output.shape)

  def test_generator_bottleneck(self):
    n = 2
    h = 32
    w = h
    c = 256

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_bottleneck(input_tensor)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_generator_up_sample(self):
    n = 2
    h = 32
    w = h
    c = 256
    c_out = 3

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_up_sample(input_tensor, c_out)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h * 4, w * 4, c_out), output.shape)

  def test_discriminator_input_hidden(self):
    n = 2
    h = 128
    w = 128
    c = 3

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_input_hidden(input_tensor)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, 2, 2, 2048), output.shape)

  def test_discriminator_output_source(self):
    n = 2
    h = 2
    w = 2
    c = 2048

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_output_source(input_tensor)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, 1), output.shape)

  def test_discriminator_output_class(self):
    n = 2
    h = 2
    w = 2
    c = 2048
    num_domain = 3

    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_output_class(input_tensor, num_domain)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, num_domain), output.shape)


if __name__ == '__main__':
  tf.test.main()
research/gan/stargan/network.py (new file, 0 → 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural network for a StarGAN model.

This module contains the Generator and Discriminator Neural Network to build a
StarGAN model.

See https://arxiv.org/abs/1711.09020 for details about the model.

See https://github.com/yunjey/StarGAN for the original pytorch implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import layers
import ops


def generator(inputs, targets):
  """Generator module.

  Piece everything together for the Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) representing the
      images/information that we want to transform.
    targets: Tensor of shape (batch_size, num_domains) representing the target
      domain the generator should transform the image/information to.

  Returns:
    Tensor of shape (batch_size, h, w, c) as the inputs.
  """
  with tf.variable_scope('generator'):
    input_with_condition = ops.condition_input_with_pixel_padding(
        inputs, targets)
    down_sample = layers.generator_down_sample(input_with_condition)
    bottleneck = layers.generator_bottleneck(down_sample)
    up_sample = layers.generator_up_sample(bottleneck, inputs.shape[-1])

  return up_sample


def discriminator(input_net, class_num):
  """Discriminator Module.

  Piece everything together and reshape the output source tensor

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L63

  Notes:
    The PyTorch Version run the reduce_mean operation later in their solver:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/solver.py#L245

  Args:
    input_net: Tensor of shape (batch_size, h, w, c) as batch of images.
    class_num: (int) number of domain to be predicted

  Returns:
    output_src: Tensor of shape (batch_size) where each value is a logit
      representing whether the image is real of fake.
    output_cls: Tensor of shape (batch_size, class_um) where each value is a
      logit representing whether the image is in the associated domain.
  """
  with tf.variable_scope('discriminator'):
    hidden = layers.discriminator_input_hidden(input_net)
    output_src = layers.discriminator_output_source(hidden)
    output_src = tf.contrib.layers.flatten(output_src)
    output_src = tf.reduce_mean(output_src, axis=1)
    output_cls = layers.discriminator_output_class(hidden, class_num)

  return output_src, output_cls
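
A quick sketch (not part of the commit) of how the two networks compose for one generator/discriminator pass; the shapes follow the unit tests below:

import tensorflow as tf

import network

n, h, w, c, num_domains = 2, 128, 128, 3, 3
images = tf.random_uniform((n, h, w, c))
target_domain = tf.one_hot([1] * n, num_domains)

fake = network.generator(images, target_domain)  # (n, h, w, c)
src_logits, cls_logits = network.discriminator(fake, num_domains)
print(src_logits.shape.as_list())  # [2] -- one real/fake logit per image
print(cls_logits.shape.as_list())  # [2, 3] -- one logit per domain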
research/gan/stargan/network_test.py (new file, 0 → 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

import network


class NetworkTest(tf.test.TestCase):

  def test_generator(self):
    n = 2
    h = 128
    w = h
    c = 4
    class_num = 3

    input_tensor = tf.random_uniform((n, h, w, c))
    target_tensor = tf.random_uniform((n, class_num))
    output_tensor = network.generator(input_tensor, target_tensor)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_discriminator(self):
    n = 2
    h = 128
    w = h
    c = 3
    class_num = 3

    input_tensor = tf.random_uniform((n, h, w, c))
    output_src_tensor, output_cls_tensor = network.discriminator(
        input_tensor, class_num)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output_src, output_cls = sess.run(
          [output_src_tensor, output_cls_tensor])
      self.assertEqual(1, len(output_src.shape))
      self.assertEqual(n, output_src.shape[0])
      self.assertTupleEqual((n, class_num), output_cls.shape)


if __name__ == '__main__':
  tf.test.main()
research/gan/stargan/ops.py
0 → 100644
View file @
54a5a577
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for a StarGAN model.
This module contains basic ops to build a StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original pytorvh implementation.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
def
_padding_arg
(
h
,
w
,
input_format
):
"""Calculate the padding shape for tf.pad().
Args:
h: (int) padding on the height dim.
w: (int) padding on the width dim.
input_format: (string) the input format as in 'NHWC' or 'HWC'.
Raises:
ValueError: If input_format is not 'NHWC' or 'HWC'.
Returns:
A two dimension array representing the padding argument.
"""
if
input_format
==
'NHWC'
:
return
[[
0
,
0
],
[
h
,
h
],
[
w
,
w
],
[
0
,
0
]]
elif
input_format
==
'HWC'
:
return
[[
h
,
h
],
[
w
,
w
],
[
0
,
0
]]
else
:
raise
ValueError
(
'Input Format %s is not supported.'
%
input_format
)
def
pad
(
input_net
,
padding_size
):
"""Padding the tensor with padding_size on both the height and width dim.
Args:
input_net: Tensor in 3D ('HWC') or 4D ('NHWC').
padding_size: (int) the size of the padding.
Notes:
Original StarGAN use zero padding instead of mirror padding.
Raises:
ValueError: If input_net Tensor is not 3D or 4D.
Returns:
Tensor with same rank as input_net but with padding on the height and width
dim.
"""
if
len
(
input_net
.
shape
)
==
4
:
return
tf
.
pad
(
input_net
,
_padding_arg
(
padding_size
,
padding_size
,
'NHWC'
))
elif
len
(
input_net
.
shape
)
==
3
:
return
tf
.
pad
(
input_net
,
_padding_arg
(
padding_size
,
padding_size
,
'HWC'
))
else
:
raise
ValueError
(
'The input tensor need to be either 3D or 4D.'
)
def condition_input_with_pixel_padding(input_tensor, condition_tensor):
  """Pad an image tensor with the condition tensor as extra color channels.

  Args:
    input_tensor: Tensor of shape (batch_size, h, w, c) representing images.
    condition_tensor: Tensor of shape (batch_size, num_domains) representing
      the associated domain for the image in input_tensor.

  Returns:
    Tensor of shape (batch_size, h, w, c + num_domains) representing the
    conditioned data.

  Raises:
    ValueError: If `input_tensor` isn't rank 4.
    ValueError: If `condition_tensor` isn't rank 2.
    ValueError: If the batch dimension (dim 0) of `input_tensor` and
      `condition_tensor` differ.
  """
  input_tensor.shape.assert_has_rank(4)
  condition_tensor.shape.assert_has_rank(2)
  input_tensor.shape[:1].assert_is_compatible_with(condition_tensor.shape[:1])

  # Broadcast the condition vector to every spatial location, then append it
  # to the image's color channels.
  condition_tensor = tf.expand_dims(condition_tensor, axis=1)
  condition_tensor = tf.expand_dims(condition_tensor, axis=1)
  condition_tensor = tf.tile(
      condition_tensor, [1, input_tensor.shape[1], input_tensor.shape[2], 1])

  return tf.concat([input_tensor, condition_tensor], -1)
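For illustration only (not part of this commit), a minimal sketch of how the two public ops above compose under the TF 1.x API; the shapes follow from the docstrings:

# Illustrative sketch, not part of the diff.
import tensorflow as tf

import ops

images = tf.random_uniform((2, 128, 128, 3))  # (N, H, W, C)
labels = tf.one_hot([0, 2], depth=5)          # (N, num_domains)

padded = ops.pad(images, padding_size=3)
conditioned = ops.condition_input_with_pixel_padding(images, labels)

with tf.Session() as sess:
  print(sess.run(tf.shape(padded)))       # [2 134 134 3]
  print(sess.run(tf.shape(conditioned)))  # [2 128 128 8]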
research/gan/stargan/ops_test.py
0 → 100644
View file @
54a5a577
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf

import ops


class OpsTest(tf.test.TestCase):

  def test_padding_arg(self):
    pad_h = 2
    pad_w = 3
    self.assertListEqual(
        [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]],
        ops._padding_arg(pad_h, pad_w, 'NHWC'))

  def test_padding_arg_specify_format(self):
    pad_h = 2
    pad_w = 3
    self.assertListEqual(
        [[pad_h, pad_h], [pad_w, pad_w], [0, 0]],
        ops._padding_arg(pad_h, pad_w, 'HWC'))

  def test_padding_arg_invalid_format(self):
    pad_h = 2
    pad_w = 3
    with self.assertRaises(ValueError):
      ops._padding_arg(pad_h, pad_w, 'INVALID')

  def test_padding(self):
    n = 2
    h = 128
    w = 64
    c = 3
    pad = 3

    test_input_tensor = tf.random_uniform((n, h, w, c))
    test_output_tensor = ops.pad(test_input_tensor, padding_size=pad)

    with self.test_session() as sess:
      output = sess.run(test_output_tensor)
      self.assertTupleEqual((n, h + pad * 2, w + pad * 2, c), output.shape)

  def test_padding_with_3D_tensor(self):
    h = 128
    w = 64
    c = 3
    pad = 3

    test_input_tensor = tf.random_uniform((h, w, c))
    test_output_tensor = ops.pad(test_input_tensor, padding_size=pad)

    with self.test_session() as sess:
      output = sess.run(test_output_tensor)
      self.assertTupleEqual((h + pad * 2, w + pad * 2, c), output.shape)

  def test_padding_with_tensor_of_invalid_shape(self):
    n = 2
    invalid_rank = 1
    h = 128
    w = 64
    c = 3
    pad = 3

    test_input_tensor = tf.random_uniform((n, invalid_rank, h, w, c))

    with self.assertRaises(ValueError):
      ops.pad(test_input_tensor, padding_size=pad)

  def test_condition_input_with_pixel_padding(self):
    n = 2
    h = 128
    w = h
    c = 3
    num_label = 5

    input_tensor = tf.random_uniform((n, h, w, c))
    label_tensor = tf.random_uniform((n, num_label))
    output_tensor = ops.condition_input_with_pixel_padding(
        input_tensor, label_tensor)

    with self.test_session() as sess:
      labels, outputs = sess.run([label_tensor, output_tensor])
      self.assertTupleEqual((n, h, w, c + num_label), outputs.shape)
      for label, output in zip(labels, outputs):
        for i in range(output.shape[0]):
          for j in range(output.shape[1]):
            self.assertListEqual(label.tolist(), output[i, j, c:].tolist())


if __name__ == '__main__':
  tf.test.main()
research/gan/stargan/train.py
0 → 100644
View file @
54a5a577
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a StarGAN model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
import
network
# FLAGS for data.
flags
.
DEFINE_multi_string
(
'image_file_patterns'
,
None
,
'List of file pattern for different domain of images. '
'(e.g.[
\'
black_hair
\'
,
\'
blond_hair
\'
,
\'
brown_hair
\'
]'
)
flags
.
DEFINE_integer
(
'batch_size'
,
6
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'patch_size'
,
128
,
'The patch size of images.'
)
flags
.
DEFINE_string
(
'train_log_dir'
,
'/tmp/stargan/'
,
'Directory where to write event logs.'
)
# FLAGS for training hyper-parameters.
flags
.
DEFINE_float
(
'generator_lr'
,
1e-4
,
'The generator learning rate.'
)
flags
.
DEFINE_float
(
'discriminator_lr'
,
1e-4
,
'The discriminator learning rate.'
)
flags
.
DEFINE_integer
(
'max_number_of_steps'
,
1000000
,
'The maximum number of gradient steps.'
)
flags
.
DEFINE_float
(
'adam_beta1'
,
0.5
,
'Adam Beta 1 for the Adam optimizer.'
)
flags
.
DEFINE_float
(
'adam_beta2'
,
0.999
,
'Adam Beta 2 for the Adam optimizer.'
)
flags
.
DEFINE_float
(
'gen_disc_step_ratio'
,
0.2
,
'Generator:Discriminator training step ratio.'
)
# FLAGS for distributed training.
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_integer
(
'ps_tasks'
,
0
,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.'
)
flags
.
DEFINE_integer
(
'task'
,
0
,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.'
)
FLAGS
=
flags
.
FLAGS
tfgan
=
tf
.
contrib
.
gan
def _define_model(images, labels):
  """Create the StarGAN Model.

  Args:
    images: `Tensor` or list of `Tensor` of shape (N, H, W, C).
    labels: `Tensor` or list of `Tensor` of shape (N, num_domains).

  Returns:
    `StarGANModel` namedtuple.
  """
  return tfgan.stargan_model(
      generator_fn=network.generator,
      discriminator_fn=network.discriminator,
      input_data=images,
      input_data_domain_label=labels)
def _get_lr(base_lr):
  """Returns a learning rate `Tensor`.

  Args:
    base_lr: A scalar float `Tensor` or a Python number. The base learning
      rate.

  Returns:
    A scalar float `Tensor` of learning rate which equals `base_lr` when the
    global training step is less than FLAGS.max_number_of_steps / 2 and
    afterwards decays linearly to zero.
  """
  global_step = tf.train.get_or_create_global_step()
  lr_constant_steps = FLAGS.max_number_of_steps // 2

  def _lr_decay():
    return tf.train.polynomial_decay(
        learning_rate=base_lr,
        global_step=(global_step - lr_constant_steps),
        decay_steps=(FLAGS.max_number_of_steps - lr_constant_steps),
        end_learning_rate=0.0)

  return tf.cond(global_step < lr_constant_steps, lambda: base_lr, _lr_decay)
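For concreteness: with max_number_of_steps = 10 and base_lr = 0.01 (the values test_get_lr in train_test.py uses), lr_constant_steps is 5, so the rate stays at 0.01 for global steps 0 through 4 and then decays linearly, e.g. 0.01 * (10 - 9) / (10 - 5) = 0.002 at step 9.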
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number. The Generator learning
      rate.
    dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
      learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  gen_opt = tf.train.AdamOptimizer(
      gen_lr, beta1=FLAGS.adam_beta1, beta2=FLAGS.adam_beta2, use_locking=True)
  dis_opt = tf.train.AdamOptimizer(
      dis_lr, beta1=FLAGS.adam_beta1, beta2=FLAGS.adam_beta2, use_locking=True)
  return gen_opt, dis_opt
def _define_train_ops(model, loss):
  """Defines train ops that train `stargan_model` with `stargan_loss`.

  Args:
    model: A `StarGANModel` namedtuple.
    loss: A `StarGANLoss` namedtuple containing all losses for
      `stargan_model`.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  gen_lr = _get_lr(FLAGS.generator_lr)
  dis_lr = _get_lr(FLAGS.discriminator_lr)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)

  return train_ops
def _define_train_step():
  """Get the training step for generator and discriminator for each GAN step.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """
  if FLAGS.gen_disc_step_ratio <= 1:
    discriminator_step = int(1 / FLAGS.gen_disc_step_ratio)
    return tfgan.GANTrainSteps(1, discriminator_step)
  else:
    generator_step = int(FLAGS.gen_disc_step_ratio)
    return tfgan.GANTrainSteps(generator_step, 1)
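For example, the default gen_disc_step_ratio of 0.2 gives GANTrainSteps(1, 5), i.e. five discriminator steps per generator step, while a ratio of 3 gives GANTrainSteps(3, 1); test_get_train_step in train_test.py below exercises both branches.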
def main(_):

  # Create the log_dir if it does not exist.
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  # Shard the model to different parameter servers.
  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):

    # Create the input dataset.
    with tf.name_scope('inputs'):
      images, labels = data_provider.provide_data(
          FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)

    # Define the model.
    with tf.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model,
        num_images=len(FLAGS.image_file_patterns) * FLAGS.batch_size,
        display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss)

    # Define the train steps.
    train_steps = _define_train_step()

    # Define a status message.
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)


if __name__ == '__main__':
  tf.flags.mark_flag_as_required('image_file_patterns')
  tf.app.run()
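A hypothetical launch sketch (the paths are placeholders, not from this commit): since image_file_patterns is a multi-string flag, it is passed once per domain.

# Hypothetical invocation with placeholder paths:
#   python train.py \
#     --image_file_patterns='/tmp/black_hair/*.jpg' \
#     --image_file_patterns='/tmp/blond_hair/*.jpg' \
#     --image_file_patterns='/tmp/brown_hair/*.jpg' \
#     --train_log_dir=/tmp/stargan/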
research/gan/stargan/train_test.py
0 → 100644
View file @
54a5a577
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stargan.train."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
import
train
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
tfgan
=
tf
.
contrib
.
gan
def
_test_generator
(
input_images
,
_
):
"""Simple generator function."""
return
input_images
*
tf
.
get_variable
(
'dummy_g'
,
initializer
=
2.0
)
def
_test_discriminator
(
inputs
,
num_domains
):
"""Differentiable dummy discriminator for StarGAN."""
hidden
=
tf
.
contrib
.
layers
.
flatten
(
inputs
)
output_src
=
tf
.
reduce_mean
(
hidden
,
axis
=
1
)
output_cls
=
tf
.
contrib
.
layers
.
fully_connected
(
inputs
=
hidden
,
num_outputs
=
num_domains
,
activation_fn
=
None
,
normalizer_fn
=
None
,
biases_initializer
=
None
)
return
output_src
,
output_cls
train
.
network
.
generator
=
_test_generator
train
.
network
.
discriminator
=
_test_discriminator
class
TrainTest
(
tf
.
test
.
TestCase
):
def
test_define_model
(
self
):
FLAGS
.
batch_size
=
2
images_shape
=
[
FLAGS
.
batch_size
,
4
,
4
,
3
]
images_np
=
np
.
zeros
(
shape
=
images_shape
)
images
=
tf
.
constant
(
images_np
,
dtype
=
tf
.
float32
)
labels
=
tf
.
one_hot
([
0
]
*
FLAGS
.
batch_size
,
2
)
model
=
train
.
_define_model
(
images
,
labels
)
self
.
assertIsInstance
(
model
,
tfgan
.
StarGANModel
)
self
.
assertShapeEqual
(
images_np
,
model
.
generated_data
)
self
.
assertShapeEqual
(
images_np
,
model
.
reconstructed_data
)
self
.
assertTrue
(
isinstance
(
model
.
discriminator_variables
,
list
))
self
.
assertTrue
(
isinstance
(
model
.
generator_variables
,
list
))
self
.
assertIsInstance
(
model
.
discriminator_scope
,
tf
.
VariableScope
)
self
.
assertTrue
(
model
.
generator_scope
,
tf
.
VariableScope
)
self
.
assertTrue
(
callable
(
model
.
discriminator_fn
))
self
.
assertTrue
(
callable
(
model
.
generator_fn
))
@
mock
.
patch
.
object
(
tf
.
train
,
'get_or_create_global_step'
,
autospec
=
True
)
def
test_get_lr
(
self
,
mock_get_or_create_global_step
):
FLAGS
.
max_number_of_steps
=
10
base_lr
=
0.01
with
self
.
test_session
(
use_gpu
=
True
)
as
sess
:
mock_get_or_create_global_step
.
return_value
=
tf
.
constant
(
2
)
lr_step2
=
sess
.
run
(
train
.
_get_lr
(
base_lr
))
mock_get_or_create_global_step
.
return_value
=
tf
.
constant
(
9
)
lr_step9
=
sess
.
run
(
train
.
_get_lr
(
base_lr
))
self
.
assertAlmostEqual
(
base_lr
,
lr_step2
)
self
.
assertAlmostEqual
(
base_lr
*
0.2
,
lr_step9
)
@
mock
.
patch
.
object
(
tf
.
summary
,
'scalar'
,
autospec
=
True
)
def
test_define_train_ops
(
self
,
mock_summary_scalar
):
FLAGS
.
batch_size
=
2
FLAGS
.
generator_lr
=
0.1
FLAGS
.
discriminator_lr
=
0.01
images_shape
=
[
FLAGS
.
batch_size
,
4
,
4
,
3
]
images
=
tf
.
zeros
(
images_shape
,
dtype
=
tf
.
float32
)
labels
=
tf
.
one_hot
([
0
]
*
FLAGS
.
batch_size
,
2
)
model
=
train
.
_define_model
(
images
,
labels
)
loss
=
tfgan
.
stargan_loss
(
model
)
train_ops
=
train
.
_define_train_ops
(
model
,
loss
)
self
.
assertIsInstance
(
train_ops
,
tfgan
.
GANTrainOps
)
mock_summary_scalar
.
assert_has_calls
([
mock
.
call
(
'generator_lr'
,
mock
.
ANY
),
mock
.
call
(
'discriminator_lr'
,
mock
.
ANY
)
])
def
test_get_train_step
(
self
):
FLAGS
.
gen_disc_step_ratio
=
0.5
train_steps
=
train
.
_define_train_step
()
self
.
assertEqual
(
1
,
train_steps
.
generator_train_steps
)
self
.
assertEqual
(
2
,
train_steps
.
discriminator_train_steps
)
FLAGS
.
gen_disc_step_ratio
=
3
train_steps
=
train
.
_define_train_step
()
self
.
assertEqual
(
3
,
train_steps
.
generator_train_steps
)
self
.
assertEqual
(
1
,
train_steps
.
discriminator_train_steps
)
@
mock
.
patch
.
object
(
train
.
data_provider
,
'provide_data'
,
autospec
=
True
)
def
test_main
(
self
,
mock_provide_data
):
FLAGS
.
image_file_patterns
=
[
'/tmp/A/*.jpg'
,
'/tmp/B/*.jpg'
,
'/tmp/C/*.jpg'
]
FLAGS
.
max_number_of_steps
=
10
FLAGS
.
batch_size
=
2
num_domains
=
3
images_shape
=
[
FLAGS
.
batch_size
,
FLAGS
.
patch_size
,
FLAGS
.
patch_size
,
3
]
img_list
=
[
tf
.
zeros
(
images_shape
)]
*
num_domains
lbl_list
=
[
tf
.
one_hot
([
0
]
*
FLAGS
.
batch_size
,
num_domains
)]
*
num_domains
mock_provide_data
.
return_value
=
(
img_list
,
lbl_list
)
train
.
main
(
None
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/stargan_estimator/data/celeba_test_split_images.npy
0 → 100644
View file @
54a5a577
File added