Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d5fc3ef0
Commit
d5fc3ef0
authored
Apr 04, 2018
by
pkulzc
Browse files
Merge remote-tracking branch 'upstream/master'
parents
6b72b5cd
57b99319
Changes
52
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1006 additions
and
5 deletions
+1006
-5
research/differential_privacy/pate/smooth_sensitivity_test.py
...arch/differential_privacy/pate/smooth_sensitivity_test.py
+126
-0
research/gan/cyclegan/data_provider.py
research/gan/cyclegan/data_provider.py
+150
-0
research/gan/cyclegan/data_provider_test.py
research/gan/cyclegan/data_provider_test.py
+101
-0
research/gan/cyclegan/inference_demo.py
research/gan/cyclegan/inference_demo.py
+150
-0
research/gan/cyclegan/inference_demo_test.py
research/gan/cyclegan/inference_demo_test.py
+99
-0
research/gan/cyclegan/testdata/00500.jpg
research/gan/cyclegan/testdata/00500.jpg
+0
-0
research/gan/cyclegan/train.py
research/gan/cyclegan/train.py
+218
-0
research/gan/cyclegan/train_test.py
research/gan/cyclegan/train_test.py
+156
-0
research/gan/pix2pix/networks_test.py
research/gan/pix2pix/networks_test.py
+1
-1
research/gan/pix2pix/train.py
research/gan/pix2pix/train.py
+1
-1
research/gan/pix2pix/train_test.py
research/gan/pix2pix/train_test.py
+1
-1
samples/outreach/demos/eager_execution.ipynb
samples/outreach/demos/eager_execution.ipynb
+3
-2
No files found.
research/differential_privacy/pate/smooth_sensitivity_test.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google3.experimental.brain.privacy.pate.pate_smooth_sensitivity."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
smooth_sensitivity
as
pate_ss
class PateSmoothSensitivityTest(unittest.TestCase):
  """Unit tests for the PATE smooth-sensitivity helpers in `pate_ss`.

  Expected values below are hard-coded regression fixtures; they were
  presumably produced by a reference run of the implementation — confirm
  before editing any constant.
  """

  def test_check_conditions(self):
    # check_conditions(num, num_teachers, sigma-like third arg) returns a
    # pair of booleans; semantics of each flag live in pate_ss.
    self.assertEqual(pate_ss.check_conditions(20, 10, 25.), (True, False))
    self.assertEqual(pate_ss.check_conditions(30, 10, 25.), (True, True))

  def _assert_all_close(self, x, y):
    """Asserts that two numpy arrays are close."""
    # Lengths first so a size mismatch fails with a clear message.
    self.assertEqual(len(x), len(y))
    # Tight relative tolerance, no absolute slack.
    self.assertTrue(np.allclose(x, y, rtol=1e-8, atol=0))

  def test_compute_local_sensitivity_bounds_gnmax(self):
    counts1 = np.array([10, 0, 0])
    sigma1 = .5
    order1 = 1.5
    answer1 = np.array(
        [3.13503646e-17, 1.60178280e-08, 5.90681786e-03] +
        [5.99981308e+00] * 7)

    # Test for "going right" in the smooth sensitivity computation.
    out1 = pate_ss.compute_local_sensitivity_bounds_gnmax(
        counts1, 10, sigma1, order1)
    self._assert_all_close(out1, answer1)

    counts2 = np.array([1000, 500, 300, 200, 0])
    sigma2 = 250.
    order2 = 10.

    # Test for "going left" in the smooth sensitivity computation.
    out2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
        counts2, 2000, sigma2, order2)
    answer2 = np.array([0.] * 298 + [2.77693450548e-7, 2.10853979548e-6] +
                       [2.73113623988e-6] * 1700)
    self._assert_all_close(out2, answer2)

  def test_compute_local_sensitivity_bounds_threshold(self):
    counts1_3 = np.array([20, 10, 0])
    num_teachers = sum(counts1_3)
    t1 = 16  # high threshold
    sigma = 2
    order = 10

    out1 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t1, sigma, order)
    answer1 = np.array([0] * 3 + [
        1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
        9.01543247e-02, 1.16054002e-01, 1.42180452e-01, 1.42180452e-01,
        1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
        9.01543266e-02, 1.16054000e-01, 1.42180452e-01, 1.68302106e-01,
        1.93127860e-01
    ] + [0] * 10)
    self._assert_all_close(out1, answer1)

    t2 = 2  # low threshold
    out2 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t2, sigma, order)
    answer2 = np.array([
        1.60212079e-01, 2.07021132e-01, 2.07021132e-01, 1.93127860e-01,
        1.68302106e-01, 1.42180452e-01, 1.16054002e-01, 9.01543247e-02,
        6.45775697e-02, 3.94153241e-02, 1.47826870e-02, 1.48454129e-04
    ] + [0] * 18)
    self._assert_all_close(out2, answer2)

    t3 = 50  # very high threshold (larger than the number of teachers).
    out3 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t3, sigma, order)
    answer3 = np.array([
        1.35750725752e-19, 1.88990500499e-17, 2.05403154065e-15,
        1.74298153642e-13, 1.15489723995e-11, 5.97584949325e-10,
        2.41486826748e-08, 7.62150641922e-07, 1.87846248741e-05,
        0.000360973025976, 0.000360973025976, 2.76377015215e-50,
        1.00904975276e-53, 2.87254164748e-57, 6.37583360761e-61,
        1.10331620211e-64, 1.48844393335e-68, 1.56535552444e-72,
        1.28328011060e-76, 8.20047697109e-81
    ] + [0] * 10)
    self._assert_all_close(out3, answer3)

    # Fractional values.
    counts4 = np.array([19.5, -5.1, 0])
    t4 = 10.1
    out4 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts4, num_teachers, t4, sigma, order)
    answer4 = np.array([
        0.0620410301, 0.0875807131, 0.113451958, 0.139561671, 0.1657074530,
        0.1908244840, 0.2070270720, 0.207027072, 0.169718100, 0.0575152142,
        0.00678695871
    ] + [0] * 6 + [0.000536304908, 0.0172181073, 0.041909870] + [0] * 10)
    self._assert_all_close(out4, answer4)
# Run the tests when executed as a script.
if __name__ == "__main__":
  unittest.main()
research/gan/cyclegan/data_provider.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing image data."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
def normalize_image(image):
  """Rescale from range [0, 255] to [-1, 1]."""
  # Cast to float first; the affine map sends 0 -> -1 and 255 -> ~1.
  image_float = tf.to_float(image)
  return (image_float - 127.5) / 127.5
def undo_normalize_image(normalized_image):
  """Convert to a numpy array that can be read by PIL."""
  # Drop the leading batch dimension (NHWC -> HWC).
  hwc_image = np.squeeze(normalized_image, axis=0)
  # Map [-1, 1] back to [0, 255] and truncate to unsigned 8-bit.
  rescaled = hwc_image * 127.5 + 127.5
  return np.uint8(rescaled)
def _sample_patch(image, patch_size):
  """Crop image to square shape and resize it to `patch_size`.

  Args:
    image: A 3D `Tensor` of HWC format.
    patch_size: A Python scalar.  The output image size.

  Returns:
    A 3D `Tensor` of HWC format which has the shape of
    [patch_size, patch_size, 3].
  """
  image_shape = tf.shape(image)
  height, width = image_shape[0], image_shape[1]
  # Center-crop to the largest square that fits in the image.
  target_size = tf.minimum(height, width)
  image = tf.image.resize_image_with_crop_or_pad(image, target_size,
                                                 target_size)
  # tf.image.resize_area only accepts 4D tensor, so expand dims first.
  image = tf.expand_dims(image, axis=0)
  image = tf.image.resize_images(image, [patch_size, patch_size])
  image = tf.squeeze(image, axis=0)

  # Force image num_channels = 3
  # Tiling by max(1, 4 - C) makes a 1-channel image 3 channels, leaves 3- and
  # 4-channel images untouched; the slice then keeps the first 3 channels.
  # NOTE(review): appears to assume C is 1, 3, or 4 — confirm with callers.
  image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
  image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
  return image
def full_image_to_patch(image, patch_size):
  """Normalizes `image` and crops/resizes it to a square patch.

  Args:
    image: A 3D `Tensor` of HWC format with values in [0, 255].
    patch_size: A Python scalar.  The output patch size.

  Returns:
    A float 3D `Tensor` of shape [patch_size, patch_size, 3] with values
    normalized to [-1, 1].
  """
  image = normalize_image(image)
  # Sample a patch of fixed size.
  image_patch = _sample_patch(image, patch_size)
  # Fail fast at graph-construction time if the static shape disagrees.
  image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
  return image_patch
def _provide_custom_dataset(image_file_pattern,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides batches of custom image data.

  Args:
    image_file_pattern: A string of glob pattern of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images.  Defaults to True.
    num_threads: Number of prefetching threads.  Defaults to 1.
    patch_size: Size of the patch to extract from the image.  Defaults to 128.

  Returns:
    A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
    representing a batch of images.
  """
  # Queue-based input pipeline (TF 1.x style): filenames -> raw bytes ->
  # decoded image -> normalized square patch -> batched tensor.
  filename_queue = tf.train.string_input_producer(
      tf.train.match_filenames_once(image_file_pattern),
      shuffle=shuffle,
      capacity=5 * batch_size)
  image_reader = tf.WholeFileReader()

  _, image_bytes = image_reader.read(filename_queue)
  image = tf.image.decode_image(image_bytes)
  image_patch = full_image_to_patch(image, patch_size)

  if shuffle:
    return tf.train.shuffle_batch(
        [image_patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    return tf.train.batch(
        [image_patch],
        batch_size=batch_size,
        num_threads=1,  # no threads so it's deterministic
        capacity=5 * batch_size)
def provide_custom_datasets(image_file_patterns,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides multiple batches of custom image data.

  Args:
    image_file_patterns: A list of glob patterns of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images.  Defaults to True.
    num_threads: Number of prefetching threads.  Defaults to 1.
    patch_size: Size of the patch to extract from the image.  Defaults to 128.

  Returns:
    A list of float `Tensor`s with the same size of `image_file_patterns`.
    Each of the `Tensor` in the list has a shape of
    [batch_size, patch_size, patch_size, 3] representing a batch of images.

  Raises:
    ValueError: If image_file_patterns is not a list or tuple.
  """
  if not isinstance(image_file_patterns, (list, tuple)):
    raise ValueError(
        '`image_file_patterns` should be either list or tuple, but was {}.'.
        format(type(image_file_patterns)))
  # One dataset per pattern, all sharing identical batching parameters.
  return [
      _provide_custom_dataset(
          pattern,
          batch_size=batch_size,
          shuffle=shuffle,
          num_threads=num_threads,
          patch_size=patch_size) for pattern in image_file_patterns
  ]
research/gan/cyclegan/data_provider_test.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
tensorflow
as
tf
import
data_provider
# Shorthand for TensorFlow's bundled mock module used by the tests.
mock = tf.test.mock
class DataProviderTest(tf.test.TestCase):
  """Tests for the CycleGAN data_provider module."""

  def test_normalize_image(self):
    # Random uint-like input in [0, 256); the normalized output must be a
    # same-shaped float32 tensor with values in [-1, 1].
    image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
    rescaled_image = data_provider.normalize_image(image)
    self.assertEqual(tf.float32, rescaled_image.dtype)
    self.assertListEqual(image.shape.as_list(),
                         rescaled_image.shape.as_list())
    with self.test_session(use_gpu=True) as sess:
      rescaled_image_out = sess.run(rescaled_image)
      self.assertTrue(np.all(np.abs(rescaled_image_out) <= 1.0))

  def test_sample_patch(self):
    # Shrinking (8 -> 7), growing (8 -> 10), and grayscale-to-RGB (1 channel
    # in, 3 channels out) must all produce [patch, patch, 3] outputs.
    image = tf.zeros(shape=(8, 8, 3))
    patch1 = data_provider._sample_patch(image, 7)
    patch2 = data_provider._sample_patch(image, 10)
    image = tf.zeros(shape=(8, 8, 1))
    patch3 = data_provider._sample_patch(image, 10)
    with self.test_session(use_gpu=True) as sess:
      self.assertTupleEqual((7, 7, 3), sess.run(patch1).shape)
      self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
      self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)

  def _get_testdata_dir(self):
    # Location of the checked-in test JPEG(s) under the test source root.
    return os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cyclegan/testdata')

  def test_custom_dataset_provider(self):
    file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
    batch_size = 3
    patch_size = 8
    images = data_provider._provide_custom_dataset(
        file_pattern, batch_size=batch_size, patch_size=patch_size)
    self.assertListEqual([batch_size, patch_size, patch_size, 3],
                         images.shape.as_list())
    self.assertEqual(tf.float32, images.dtype)
    with self.test_session(use_gpu=True) as sess:
      # match_filenames_once creates a local variable; initialize it before
      # starting the queue runners.
      sess.run(tf.local_variables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
                              images_out.shape)
        self.assertTrue(np.all(np.abs(images_out) <= 1.0))

  def test_custom_datasets_provider(self):
    file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
    batch_size = 3
    patch_size = 8
    images_list = data_provider.provide_custom_datasets(
        [file_pattern, file_pattern],
        batch_size=batch_size,
        patch_size=patch_size)
    for images in images_list:
      self.assertListEqual([batch_size, patch_size, patch_size, 3],
                           images.shape.as_list())
      self.assertEqual(tf.float32, images.dtype)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.local_variables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out_list = sess.run(images_list)
        for images_out in images_out_list:
          self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
                                images_out.shape)
          self.assertTrue(np.all(np.abs(images_out) <= 1.0))
# Run the tests when executed as a script.
if __name__ == '__main__':
  tf.test.main()
research/gan/cyclegan/inference_demo.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Demo that makes inference requests against a running inference server."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
PIL
import
tensorflow
as
tf
import
data_provider
import
networks
# Command-line flags for the inference demo.
flags = tf.flags
tfgan = tf.contrib.gan

flags.DEFINE_string('checkpoint_path', '',
                    'CycleGAN checkpoint path created by train.py. '
                    '(e.g. "/mylogdir/model.ckpt-18442")')

flags.DEFINE_string(
    'image_set_x_glob', '',
    'Optional: Glob path to images of class X to feed through the CycleGAN.')

flags.DEFINE_string(
    'image_set_y_glob', '',
    'Optional: Glob path to images of class Y to feed through the CycleGAN.')

flags.DEFINE_string(
    'generated_x_dir', '/tmp/generated_x/',
    'If image_set_y_glob is defined, where to output the generated X '
    'images.')

flags.DEFINE_string(
    'generated_y_dir', '/tmp/generated_y/',
    'If image_set_x_glob is defined, where to output the generated Y '
    'images.')

flags.DEFINE_integer('patch_dim', 128,
                     'The patch size of images that was used in train.py.')

FLAGS = flags.FLAGS
def _make_dir_if_not_exists(dir_path):
  """Make a directory if it does not exist."""
  if not tf.gfile.Exists(dir_path):
    tf.gfile.MakeDirs(dir_path)
def
_file_output_path
(
dir_path
,
input_file_path
):
"""Create output path for an individual file."""
return
os
.
path
.
join
(
dir_path
,
os
.
path
.
basename
(
input_file_path
))
def make_inference_graph(model_name, patch_dim):
  """Build the inference graph for either the X2Y or Y2X GAN.

  Args:
    model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
    patch_dim: An integer size of patches to feed to the generator.

  Returns:
    Tuple of (input_placeholder, generated_tensor).
  """
  # Arbitrary-sized HWC input; full_image_to_patch normalizes and resizes it.
  input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])

  # Expand HWC to NHWC
  images_x = tf.expand_dims(
      data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)

  # Variable scopes must match those used in train.py so the checkpoint's
  # variables map onto this graph.
  with tf.variable_scope(model_name):
    with tf.variable_scope('Generator'):
      generated = networks.generator(images_x)
  return input_hwc_pl, generated
def export(sess, input_pl, output_tensor, input_file_pattern, output_dir):
  """Exports inference outputs to an output directory.

  Args:
    sess: tf.Session with variables already loaded.
    input_pl: tf.Placeholder for input (HWC format).
    output_tensor: Tensor for generated output images.
    input_file_pattern: Glob file pattern for input images.
    output_dir: Output directory.
  """
  if output_dir:
    _make_dir_if_not_exists(output_dir)

  # No-op when the glob is empty; each matched file is processed one at a
  # time (batch size 1).
  if input_file_pattern:
    for file_path in tf.gfile.Glob(input_file_pattern):
      # Grab a single image and run it through inference
      input_np = np.asarray(PIL.Image.open(file_path))
      output_np = sess.run(output_tensor, feed_dict={input_pl: input_np})
      image_np = data_provider.undo_normalize_image(output_np)
      # Output keeps the input's base file name.
      output_path = _file_output_path(output_dir, file_path)
      PIL.Image.fromarray(image_np).save(output_path)
def _validate_flags():
  """Registers flag validators: checkpoint is required; each output dir is
  required only when the corresponding input glob is set."""
  flags.register_validator('checkpoint_path', bool,
                           'Must provide `checkpoint_path`.')
  # generated_x_dir is needed only if Y-images are being translated to X.
  flags.register_validator(
      'generated_x_dir',
      lambda x: False if (FLAGS.image_set_y_glob and not x) else True,
      'Must provide `generated_x_dir`.')
  # generated_y_dir is needed only if X-images are being translated to Y.
  flags.register_validator(
      'generated_y_dir',
      lambda x: False if (FLAGS.image_set_x_glob and not x) else True,
      'Must provide `generated_y_dir`.')
def main(_):
  """Restores a CycleGAN checkpoint and translates images in both directions."""
  _validate_flags()
  # Two one-way inference graphs sharing the training variable scopes.
  images_x_hwc_pl, generated_y = make_inference_graph('ModelX2Y',
                                                      FLAGS.patch_dim)
  images_y_hwc_pl, generated_x = make_inference_graph('ModelY2X',
                                                      FLAGS.patch_dim)

  # Restore all the variables that were saved in the checkpoint.
  saver = tf.train.Saver()
  with tf.Session() as sess:
    saver.restore(sess, FLAGS.checkpoint_path)

    # X inputs produce generated Y images, and vice versa.
    export(sess, images_x_hwc_pl, generated_y, FLAGS.image_set_x_glob,
           FLAGS.generated_y_dir)
    export(sess, images_y_hwc_pl, generated_x, FLAGS.image_set_y_glob,
           FLAGS.generated_x_dir)
# tf.app.run parses flags and then calls main().
if __name__ == '__main__':
  tf.app.run()
research/gan/cyclegan/inference_demo_test.py
0 → 100644
View file @
d5fc3ef0
"""Tests for CycleGAN inference demo."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
PIL
import
tensorflow
as
tf
import
inference_demo
import
train
# Shorthands used throughout the test.
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def
_basenames_from_glob
(
file_glob
):
return
[
os
.
path
.
basename
(
file_path
)
for
file_path
in
tf
.
gfile
.
Glob
(
file_glob
)]
class InferenceDemoTest(tf.test.TestCase):
  """End-to-end test: train one step, save, then run the inference demo."""

  def setUp(self):
    # Temp locations for the checkpoint and the generated images.
    self._export_dir = os.path.join(FLAGS.test_tmpdir, 'export')
    self._ckpt_path = os.path.join(self._export_dir, 'model.ckpt')
    self._image_glob = os.path.join(
        FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cyclegan/testdata',
        '*.jpg')
    self._genx_dir = os.path.join(FLAGS.test_tmpdir, 'genx')
    self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')

  @mock.patch.object(tfgan, 'gan_train', autospec=True)
  def testTrainingAndInferenceGraphsAreCompatible(self, unused_mock_gan_train):
    # Training and inference graphs can get out of sync if changes are made
    # to one but not the other. This test will keep them in sync.

    # Save the training graph
    train_sess = tf.Session()
    FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
    FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
    FLAGS.batch_size = 3
    FLAGS.patch_size = 128
    FLAGS.generator_lr = 0.02
    FLAGS.discriminator_lr = 0.3
    FLAGS.train_log_dir = self._export_dir
    FLAGS.master = 'master'
    FLAGS.task = 0
    FLAGS.cycle_consistency_loss_weight = 2.0
    FLAGS.max_number_of_steps = 1
    # gan_train is mocked out, so this only builds the training graph.
    train.main(None)
    init_op = tf.global_variables_initializer()
    train_sess.run(init_op)
    train_saver = tf.train.Saver()
    train_saver.save(train_sess, save_path=self._ckpt_path)

    # Create inference graph
    tf.reset_default_graph()
    FLAGS.patch_dim = FLAGS.patch_size
    tf.logging.info('dir_path: {}'.format(os.listdir(self._export_dir)))
    FLAGS.checkpoint_path = self._ckpt_path
    FLAGS.image_set_x_glob = self._image_glob
    FLAGS.image_set_y_glob = self._image_glob
    FLAGS.generated_x_dir = self._genx_dir
    FLAGS.generated_y_dir = self._geny_dir
    # Restoring the training checkpoint into the inference graph fails if the
    # two graphs have diverged.
    inference_demo.main(None)
    tf.logging.info('gen x: {}'.format(os.listdir(self._genx_dir)))

    # Check that the image names match
    self.assertSetEqual(
        set(_basenames_from_glob(FLAGS.image_set_x_glob)),
        set(os.listdir(FLAGS.generated_y_dir)))
    self.assertSetEqual(
        set(_basenames_from_glob(FLAGS.image_set_y_glob)),
        set(os.listdir(FLAGS.generated_x_dir)))
    # Check that each image in the directory looks as expected
    # BUG FIX: the list previously contained `generated_x_dir` twice, so the
    # generated Y images were never checked for realism.
    for directory in [FLAGS.generated_x_dir, FLAGS.generated_y_dir]:
      for base_name in os.listdir(directory):
        image_path = os.path.join(directory, base_name)
        self.assertRealisticImage(image_path)

  def assertRealisticImage(self, image_path):
    """Asserts `image_path` holds a 3D image that is not degenerate."""
    tf.logging.info('Testing {} for realism.'.format(image_path))
    # If the normalization is off or forgotten, then the generated image is
    # all one pixel value. This tests that different pixel values are achieved.
    input_np = np.asarray(PIL.Image.open(image_path))
    self.assertEqual(len(input_np.shape), 3)
    self.assertGreaterEqual(input_np.shape[0], 50)
    self.assertGreaterEqual(input_np.shape[1], 50)
    self.assertGreater(np.mean(input_np), 20)
    self.assertGreater(np.var(input_np), 100)
# Run the tests when executed as a script.
if __name__ == '__main__':
  tf.test.main()
research/gan/cyclegan/testdata/00500.jpg
0 → 100644
View file @
d5fc3ef0
12.9 KB
research/gan/cyclegan/train.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a CycleGAN model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
data_provider
import
networks
# Command-line flags for CycleGAN training.
flags = tf.flags
tfgan = tf.contrib.gan

flags.DEFINE_string('image_set_x_file_pattern', None,
                    'File pattern of images in image set X')

flags.DEFINE_string('image_set_y_file_pattern', None,
                    'File pattern of images in image set Y')

flags.DEFINE_integer('batch_size', 1, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 64, 'The patch size of images.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/cyclegan/',
                    'Directory where to write event logs.')

flags.DEFINE_float('generator_lr', 0.0002,
                   'The compression model learning rate.')

flags.DEFINE_float('discriminator_lr', 0.0001,
                   'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 500000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_float('cycle_consistency_loss_weight', 10.0,
                   'The weight of cycle consistency loss')

FLAGS = flags.FLAGS
def _define_model(images_x, images_y):
  """Defines a CycleGAN model that maps between images_x and images_y.

  Args:
    images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
    images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

  Returns:
    A `CycleGANModel` namedtuple.
  """
  cyclegan_model = tfgan.cyclegan_model(
      generator_fn=networks.generator,
      discriminator_fn=networks.discriminator,
      data_x=images_x,
      data_y=images_y)

  # Add summaries for generated images.
  tfgan.eval.add_image_comparison_summaries(
      cyclegan_model, num_comparisons=3, display_diffs=False)
  # Grid side length is floor(sqrt(batch_size)).
  tfgan.eval.add_gan_model_image_summaries(
      cyclegan_model, grid_size=int(np.sqrt(FLAGS.batch_size)))

  return cyclegan_model
def _get_lr(base_lr):
  """Returns a learning rate `Tensor`.

  Args:
    base_lr: A scalar float `Tensor` or a Python number.  The base learning
      rate.

  Returns:
    A scalar float `Tensor` of learning rate which equals `base_lr` when the
    global training step is less than FLAGS.max_number_of_steps / 2, afterwards
    it linearly decays to zero.
  """
  global_step = tf.train.get_or_create_global_step()
  lr_constant_steps = FLAGS.max_number_of_steps // 2

  def _lr_decay():
    # Linear decay (power=1 default) over the second half of training.
    return tf.train.polynomial_decay(
        learning_rate=base_lr,
        global_step=(global_step - lr_constant_steps),
        decay_steps=(FLAGS.max_number_of_steps - lr_constant_steps),
        end_learning_rate=0.0)

  return tf.cond(global_step < lr_constant_steps, lambda: base_lr, _lr_decay)
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number.  The Generator learning
      rate.
    dis_lr: A scalar float `Tensor` or a Python number.  The Discriminator
      learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  # beta1 follows
  # https://github.com/junyanz/CycleGAN/blob/master/options.lua
  gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
  dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
  return gen_opt, dis_opt
def _define_train_ops(cyclegan_model, cyclegan_loss):
  """Defines train ops that trains `cyclegan_model` with `cyclegan_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
      `cyclegan_model`.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  # Both learning rates follow the same constant-then-linear-decay schedule.
  gen_lr = _get_lr(FLAGS.generator_lr)
  dis_lr = _get_lr(FLAGS.discriminator_lr)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
  train_ops = tfgan.gan_train_ops(
      cyclegan_model,
      cyclegan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)
  return train_ops
def main(_):
  """Builds the CycleGAN training pipeline and runs training."""
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  # Place variables on parameter servers when ps_tasks > 0.
  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      images_x, images_y = data_provider.provide_custom_datasets(
          [FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
          batch_size=FLAGS.batch_size,
          patch_size=FLAGS.patch_size)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

    # Training
    # One generator step and one discriminator step per iteration.
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')
    # A zero step budget means "build the graph only" — skip training.
    if not FLAGS.max_number_of_steps:
      return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
# Both input patterns are mandatory when run as a script; tf.app.run parses
# flags and calls main().
if __name__ == '__main__':
  tf.flags.mark_flag_as_required('image_set_x_file_pattern')
  tf.flags.mark_flag_as_required('image_set_y_file_pattern')
  tf.app.run()
research/gan/cyclegan/train_test.py
0 → 100644
View file @
d5fc3ef0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cyclegan.train."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
train
# Shorthands used throughout the test.
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _test_generator(input_images):
  """Simple generator function."""
  # Scales inputs by a trainable scalar so the graph has a variable to train.
  return input_images * tf.get_variable('dummy_g', initializer=2.0)
def _test_discriminator(image_batch, unused_conditioning=None):
  """Simple discriminator function."""
  # Scales by a trainable scalar and flattens to the [batch, -1] shape a
  # discriminator output is expected to have.
  return tf.contrib.layers.flatten(
      image_batch * tf.get_variable('dummy_d', initializer=2.0))
# Replace the real networks with the lightweight stand-ins for all tests in
# this module (module-level monkey-patch).
train.networks.generator = _test_generator
train.networks.discriminator = _test_discriminator
class TrainTest(tf.test.TestCase):
  """Unit tests for the helper functions and main() of cyclegan train.py."""

  @mock.patch.object(tfgan, 'eval', autospec=True)
  def test_define_model(self, mock_eval):
    """_define_model returns a CycleGANModel with shape-preserving cycles."""
    FLAGS.batch_size = 2
    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images_x_np = np.zeros(shape=images_shape)
    images_y_np = np.zeros(shape=images_shape)
    images_x = tf.constant(images_x_np, dtype=tf.float32)
    images_y = tf.constant(images_y_np, dtype=tf.float32)
    cyclegan_model = train._define_model(images_x, images_y)
    self.assertIsInstance(cyclegan_model, tfgan.CycleGANModel)
    # Reconstructions x->y->x and y->x->y must keep the input shape.
    self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x)
    self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y)
    # Model definition should also wire up both kinds of image summaries
    # (tfgan.eval is patched, so only the calls are checked, not the ops).
    mock_eval.add_image_comparison_summaries.assert_called_once()
    mock_eval.add_gan_model_image_summaries.assert_called_once()

  @mock.patch.object(train.networks, 'generator', autospec=True)
  @mock.patch.object(train.networks, 'discriminator', autospec=True)
  @mock.patch.object(tf.train, 'get_or_create_global_step', autospec=True)
  def test_get_lr(self, mock_get_or_create_global_step,
                  unused_mock_discriminator, unused_mock_generator):
    """_get_lr holds the base rate early, then decays late in training.

    With max_number_of_steps=10: at global step 2 the rate equals base_lr;
    at step 9 it has decayed to base_lr * 0.2.
    """
    FLAGS.max_number_of_steps = 10
    base_lr = 0.01
    with self.test_session(use_gpu=True) as sess:
      # The global step is faked via the patched getter, so the same
      # _get_lr graph can be evaluated at two different "times".
      mock_get_or_create_global_step.return_value = tf.constant(2)
      lr_step2 = sess.run(train._get_lr(base_lr))
      mock_get_or_create_global_step.return_value = tf.constant(9)
      lr_step9 = sess.run(train._get_lr(base_lr))
    self.assertAlmostEqual(base_lr, lr_step2)
    self.assertAlmostEqual(base_lr * 0.2, lr_step9)

  @mock.patch.object(tf.train, 'AdamOptimizer', autospec=True)
  def test_get_optimizer(self, mock_adam_optimizer):
    """_get_optimizer builds one locking Adam optimizer per learning rate."""
    gen_lr, dis_lr = 0.1, 0.01
    train._get_optimizer(gen_lr=gen_lr, dis_lr=dis_lr)
    # use_locking=True is asserted for both optimizers; beta1 is
    # implementation-defined here, so only its presence is checked.
    mock_adam_optimizer.assert_has_calls([
        mock.call(gen_lr, beta1=mock.ANY, use_locking=True),
        mock.call(dis_lr, beta1=mock.ANY, use_locking=True)
    ])

  @mock.patch.object(tf.summary, 'scalar', autospec=True)
  def test_define_train_ops(self, mock_summary_scalar):
    """_define_train_ops returns GANTrainOps and records LR summaries."""
    FLAGS.batch_size = 2
    FLAGS.generator_lr = 0.1
    FLAGS.discriminator_lr = 0.01
    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images_x = tf.zeros(images_shape, dtype=tf.float32)
    images_y = tf.zeros(images_shape, dtype=tf.float32)
    cyclegan_model = train._define_model(images_x, images_y)
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model, cycle_consistency_loss_weight=10.0)
    train_ops = train._define_train_ops(cyclegan_model, cyclegan_loss)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    # Both learning rates must be exported as scalar summaries.
    mock_summary_scalar.assert_has_calls([
        mock.call('generator_lr', mock.ANY),
        mock.call('discriminator_lr', mock.ANY)
    ])

  @mock.patch.object(tf, 'gfile', autospec=True)
  @mock.patch.object(train, 'data_provider', autospec=True)
  @mock.patch.object(train, '_define_model', autospec=True)
  @mock.patch.object(tfgan, 'cyclegan_loss', autospec=True)
  @mock.patch.object(train, '_define_train_ops', autospec=True)
  @mock.patch.object(tfgan, 'gan_train', autospec=True)
  def test_main(self, mock_gan_train, mock_define_train_ops,
                mock_cyclegan_loss, mock_define_model, mock_data_provider,
                mock_gfile):
    """main() threads flag values through the data/model/loss/train pipeline.

    Every collaborator is patched (note: mock arguments arrive in reverse
    decorator order), so this only verifies the plumbing between stages,
    not any real training behavior.
    """
    FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
    FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
    FLAGS.batch_size = 3
    FLAGS.patch_size = 8
    FLAGS.generator_lr = 0.02
    FLAGS.discriminator_lr = 0.3
    FLAGS.train_log_dir = '/tmp/foo'
    FLAGS.master = 'master'
    FLAGS.task = 0
    FLAGS.cycle_consistency_loss_weight = 2.0
    FLAGS.max_number_of_steps = 1
    # The data provider must return one batch per image set.
    mock_data_provider.provide_custom_datasets.return_value = (
        tf.zeros([1, 2], dtype=tf.float32),
        tf.zeros([1, 2], dtype=tf.float32))
    train.main(None)
    mock_data_provider.provide_custom_datasets.assert_called_once_with(
        ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
    mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
    mock_cyclegan_loss.assert_called_once_with(
        mock_define_model.return_value,
        cycle_consistency_loss_weight=2.0,
        tensor_pool_fn=mock.ANY)
    mock_define_train_ops.assert_called_once_with(
        mock_define_model.return_value, mock_cyclegan_loss.return_value)
    # task == 0 makes this worker the chief.
    mock_gan_train.assert_called_once_with(
        mock_define_train_ops.return_value,
        '/tmp/foo',
        get_hooks_fn=mock.ANY,
        hooks=mock.ANY,
        master='master',
        is_chief=True)
# Standard TensorFlow test-file entry point.
if __name__ == '__main__':
  tf.test.main()
research/gan/pix2pix/networks_test.py
View file @
d5fc3ef0
...
...
@@ -19,7 +19,7 @@ from __future__ import division
from
__future__
import
print_function
import
tensorflow
as
tf
from
google3.third_party.tensorflow_models.gan.pix2pix
import
networks
import
networks
class
Pix2PixTest
(
tf
.
test
.
TestCase
):
...
...
research/gan/pix2pix/train.py
View file @
d5fc3ef0
...
...
@@ -23,7 +23,7 @@ from __future__ import print_function
import
tensorflow
as
tf
import
data_provider
from
google3.third_party.tensorflow_models.gan.pix2pix
import
networks
import
networks
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
...
...
research/gan/pix2pix/train_test.py
View file @
d5fc3ef0
...
...
@@ -21,7 +21,7 @@ from __future__ import print_function
import
numpy
as
np
import
tensorflow
as
tf
from
google3.third_party.tensorflow_models.gan.pix2pix
import
train
import
train
FLAGS
=
tf
.
flags
.
FLAGS
mock
=
tf
.
test
.
mock
...
...
samples/outreach/demos/eager_execution.ipynb
View file @
d5fc3ef0
...
...
@@ -282,9 +282,10 @@
},
"cell_type": "code",
"source": [
"if tf.test.is_gpu_available()
> 0
:\n",
"if tf.test.is_gpu_available():\n",
" with tf.device(tf.test.gpu_device_name()):\n",
" print(tf.matmul(A, A))"
" B = tf.constant([[2.0, 0.0], [0.0, 3.0]])\n",
" print(tf.matmul(B, B))"
],
"execution_count": 0,
"outputs": []
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment