Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4f7074f6
Commit
4f7074f6
authored
Jun 05, 2018
by
Joel Shor
Committed by
Joel Shor
Jun 05, 2018
Browse files
Project import generated by Copybara.
PiperOrigin-RevId: 199251174
parent
30cf3752
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
94 additions
and
94 deletions
+94
-94
research/gan/cifar/data_provider_test.py
research/gan/cifar/data_provider_test.py
+2
-1
research/gan/cifar/eval.py
research/gan/cifar/eval.py
+4
-3
research/gan/cifar/eval_test.py
research/gan/cifar/eval_test.py
+9
-12
research/gan/cifar/train.py
research/gan/cifar/train.py
+3
-2
research/gan/cifar/train_test.py
research/gan/cifar/train_test.py
+9
-11
research/gan/cyclegan/data_provider_test.py
research/gan/cyclegan/data_provider_test.py
+9
-7
research/gan/cyclegan/inference_demo.py
research/gan/cyclegan/inference_demo.py
+3
-2
research/gan/cyclegan/inference_demo_test.py
research/gan/cyclegan/inference_demo_test.py
+5
-3
research/gan/cyclegan/train.py
research/gan/cyclegan/train.py
+2
-6
research/gan/cyclegan/train_test.py
research/gan/cyclegan/train_test.py
+2
-4
research/gan/image_compression/data_provider_test.py
research/gan/image_compression/data_provider_test.py
+8
-9
research/gan/image_compression/eval.py
research/gan/image_compression/eval.py
+3
-2
research/gan/image_compression/train.py
research/gan/image_compression/train.py
+3
-3
research/gan/image_compression/train_test.py
research/gan/image_compression/train_test.py
+8
-9
research/gan/mnist/conditional_eval.py
research/gan/mnist/conditional_eval.py
+3
-2
research/gan/mnist/conditional_eval_test.py
research/gan/mnist/conditional_eval_test.py
+3
-3
research/gan/mnist/data_provider_test.py
research/gan/mnist/data_provider_test.py
+2
-1
research/gan/mnist/eval.py
research/gan/mnist/eval.py
+3
-2
research/gan/mnist/eval_test.py
research/gan/mnist/eval_test.py
+10
-10
research/gan/mnist/infogan_eval.py
research/gan/mnist/infogan_eval.py
+3
-2
No files found.
research/gan/cifar/data_provider_test.py
View file @
4f7074f6
...
...
@@ -20,6 +20,7 @@ from __future__ import print_function
import
os
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase):
def
test_cifar10_train_set
(
self
):
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cifar/testdata'
)
batch_size
=
4
...
...
research/gan/cifar/eval.py
View file @
4f7074f6
...
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
...
...
@@ -25,8 +27,7 @@ import networks
import
util
flags
=
tf
.
flags
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
tfgan
=
tf
.
contrib
.
gan
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
...
...
@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/cifar/eval_test.py
View file @
4f7074f6
...
...
@@ -19,16 +19,22 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
eval
# pylint:disable=redefined-builtin
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
class
EvalTest
(
tf
.
test
.
TestCase
):
class
EvalTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
eval_real_images
,
conditional_eval
):
@
parameterized
.
named_parameters
(
(
'RealData'
,
True
,
False
),
(
'GeneratedData'
,
False
,
False
),
(
'GeneratedDataConditional'
,
False
,
True
))
def
test_build_graph
(
self
,
eval_real_images
,
conditional_eval
):
FLAGS
.
eval_real_images
=
eval_real_images
FLAGS
.
conditional_eval
=
conditional_eval
# Mock `frechet_inception_distance` and `inception_score`, which are
...
...
@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase):
mock_iscore
.
return_value
=
1.0
eval
.
main
(
None
,
run_eval_loop
=
False
)
def
test_build_graph_realdata
(
self
):
self
.
_test_build_graph_helper
(
True
,
False
)
def
test_build_graph_generateddata
(
self
):
self
.
_test_build_graph_helper
(
False
,
False
)
def
test_build_graph_generateddataconditional
(
self
):
self
.
_test_build_graph_helper
(
False
,
True
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/cifar/train.py
View file @
4f7074f6
...
...
@@ -18,6 +18,8 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
data_provider
...
...
@@ -25,7 +27,6 @@ import networks
tfgan
=
tf
.
contrib
.
gan
flags
=
tf
.
flags
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
...
...
@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
tf
.
app
.
run
()
research/gan/cifar/train_test.py
View file @
4f7074f6
...
...
@@ -19,17 +19,23 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
class
TrainTest
(
tf
.
test
.
TestCase
):
class
TrainTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
conditional
,
use_sync_replicas
):
@
parameterized
.
named_parameters
(
(
'Unconditional'
,
False
,
False
),
(
'Conditional'
,
True
,
False
),
(
'SyncReplicas'
,
False
,
True
))
def
test_build_graph_helper
(
self
,
conditional
,
use_sync_replicas
):
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
conditional
=
conditional
FLAGS
.
use_sync_replicas
=
use_sync_replicas
...
...
@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase):
mock_imgs
,
mock_lbls
,
None
,
None
)
train
.
main
(
None
)
def
test_build_graph_unconditional
(
self
):
self
.
_test_build_graph_helper
(
False
,
False
)
def
test_build_graph_conditional
(
self
):
self
.
_test_build_graph_helper
(
True
,
False
)
def
test_build_graph_syncreplicas
(
self
):
self
.
_test_build_graph_helper
(
False
,
True
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/cyclegan/data_provider_test.py
View file @
4f7074f6
...
...
@@ -20,6 +20,7 @@ from __future__ import print_function
import
os
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -31,6 +32,12 @@ mock = tf.test.mock
class
DataProviderTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
DataProviderTest
,
self
).
setUp
()
self
.
testdata_dir
=
os
.
path
.
join
(
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata'
)
def
test_normalize_image
(
self
):
image
=
tf
.
random_uniform
(
shape
=
(
8
,
8
,
3
),
maxval
=
256
,
dtype
=
tf
.
int32
)
rescaled_image
=
data_provider
.
normalize_image
(
image
)
...
...
@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase):
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch2
).
shape
)
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch3
).
shape
)
def
_get_testdata_dir
(
self
):
return
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata'
)
def
test_custom_dataset_provider
(
self
):
file_pattern
=
os
.
path
.
join
(
self
.
_get_
testdata_dir
()
,
'*.jpg'
)
file_pattern
=
os
.
path
.
join
(
self
.
testdata_dir
,
'*.jpg'
)
batch_size
=
3
patch_size
=
8
images
=
data_provider
.
_provide_custom_dataset
(
...
...
@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase):
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
def
test_custom_datasets_provider
(
self
):
file_pattern
=
os
.
path
.
join
(
self
.
_get_
testdata_dir
()
,
'*.jpg'
)
file_pattern
=
os
.
path
.
join
(
self
.
testdata_dir
,
'*.jpg'
)
batch_size
=
3
patch_size
=
8
images_list
=
data_provider
.
provide_custom_datasets
(
...
...
research/gan/cyclegan/inference_demo.py
View file @
4f7074f6
...
...
@@ -21,6 +21,8 @@ from __future__ import print_function
import
os
from
absl
import
app
from
absl
import
flags
import
numpy
as
np
import
PIL
import
tensorflow
as
tf
...
...
@@ -28,7 +30,6 @@ import tensorflow as tf
import
data_provider
import
networks
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
flags
.
DEFINE_string
(
'checkpoint_path'
,
''
,
...
...
@@ -147,4 +148,4 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
()
research/gan/cyclegan/inference_demo_test.py
View file @
4f7074f6
...
...
@@ -5,8 +5,10 @@ from __future__ import division
from
__future__
import
print_function
import
os
from
absl
import
logging
import
numpy
as
np
import
PIL
import
tensorflow
as
tf
import
inference_demo
...
...
@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase):
# Create inference graph
tf
.
reset_default_graph
()
FLAGS
.
patch_dim
=
FLAGS
.
patch_size
tf
.
logging
.
info
(
'dir_path:
{}'
.
format
(
os
.
listdir
(
self
.
_export_dir
))
)
logging
.
info
(
'dir_path:
%s'
,
os
.
listdir
(
self
.
_export_dir
))
FLAGS
.
checkpoint_path
=
self
.
_ckpt_path
FLAGS
.
image_set_x_glob
=
self
.
_image_glob
FLAGS
.
image_set_y_glob
=
self
.
_image_glob
...
...
@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase):
FLAGS
.
generated_y_dir
=
self
.
_geny_dir
inference_demo
.
main
(
None
)
tf
.
logging
.
info
(
'gen x:
{}'
.
format
(
os
.
listdir
(
self
.
_genx_dir
))
)
logging
.
info
(
'gen x:
%s'
,
os
.
listdir
(
self
.
_genx_dir
))
# Check that the image names match
self
.
assertSetEqual
(
...
...
@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase):
self
.
assertRealisticImage
(
image_path
)
def
assertRealisticImage
(
self
,
image_path
):
tf
.
logging
.
info
(
'Testing
{}
for realism.'
.
format
(
image_path
)
)
logging
.
info
(
'Testing
%s
for realism.'
,
image_path
)
# If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved.
input_np
=
np
.
asarray
(
PIL
.
Image
.
open
(
image_path
))
...
...
research/gan/cyclegan/train.py
View file @
4f7074f6
...
...
@@ -19,13 +19,12 @@ from __future__ import division
from
__future__
import
print_function
import
numpy
as
np
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
import
networks
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
...
...
@@ -87,10 +86,7 @@ def _define_model(images_x, images_y):
data_y
=
images_y
)
# Add summaries for generated images.
tfgan
.
eval
.
add_image_comparison_summaries
(
cyclegan_model
,
num_comparisons
=
3
,
display_diffs
=
False
)
tfgan
.
eval
.
add_gan_model_image_summaries
(
cyclegan_model
,
grid_size
=
int
(
np
.
sqrt
(
FLAGS
.
batch_size
)))
tfgan
.
eval
.
add_cyclegan_image_summaries
(
cyclegan_model
)
return
cyclegan_model
...
...
research/gan/cyclegan/train_test.py
View file @
4f7074f6
...
...
@@ -19,12 +19,13 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
tfgan
=
tf
.
contrib
.
gan
...
...
@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase):
self
.
assertShapeEqual
(
images_x_np
,
cyclegan_model
.
reconstructed_x
)
self
.
assertShapeEqual
(
images_y_np
,
cyclegan_model
.
reconstructed_y
)
mock_eval
.
add_image_comparison_summaries
.
assert_called_once
()
mock_eval
.
add_gan_model_image_summaries
.
assert_called_once
()
@
mock
.
patch
.
object
(
train
.
networks
,
'generator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
train
.
networks
,
'discriminator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
...
...
research/gan/image_compression/data_provider_test.py
View file @
4f7074f6
...
...
@@ -20,6 +20,8 @@ from __future__ import print_function
import
os
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -27,11 +29,14 @@ import tensorflow as tf
import
data_provider
class
DataProviderTest
(
tf
.
test
.
TestCase
):
class
DataProviderTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_data_provider_helper
(
self
,
split_name
):
@
parameterized
.
named_parameters
(
(
'train'
,
'train'
),
(
'validation'
,
'validation'
))
def
test_data_provider
(
self
,
split_name
):
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/'
)
batch_size
=
3
...
...
@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase):
# Check range.
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
def
test_data_provider_train
(
self
):
self
.
_test_data_provider_helper
(
'train'
)
def
test_data_provider_validation
(
self
):
self
.
_test_data_provider_helper
(
'validation'
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/gan/image_compression/eval.py
View file @
4f7074f6
...
...
@@ -20,13 +20,14 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
import
networks
import
summaries
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
...
...
@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
()
research/gan/image_compression/train.py
View file @
4f7074f6
...
...
@@ -20,7 +20,8 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
data_provider
...
...
@@ -29,7 +30,6 @@ import summaries
tfgan
=
tf
.
contrib
.
gan
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
...
...
@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data,
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
tf
.
app
.
run
()
research/gan/image_compression/train_test.py
View file @
4f7074f6
...
...
@@ -19,17 +19,22 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
class
TrainTest
(
tf
.
test
.
TestCase
):
class
TrainTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
weight_factor
):
@
parameterized
.
named_parameters
(
(
'NoAdversarialLoss'
,
0.0
),
(
'AdversarialLoss'
,
1.0
))
def
test_build_graph
(
self
,
weight_factor
):
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
weight_factor
=
weight_factor
...
...
@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase):
mock_data_provider
.
provide_data
.
return_value
=
mock_imgs
train
.
main
(
None
)
def
test_build_graph_noadversarialloss
(
self
):
self
.
_test_build_graph_helper
(
0.0
)
def
test_build_graph_adversarialloss
(
self
):
self
.
_test_build_graph_helper
(
1.0
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
...
...
research/gan/mnist/conditional_eval.py
View file @
4f7074f6
...
...
@@ -19,13 +19,14 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
import
networks
import
util
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
...
...
@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/mnist/conditional_eval_test.py
View file @
4f7074f6
...
...
@@ -18,15 +18,15 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
absl.testing
import
absltest
import
conditional_eval
class
ConditionalEvalTest
(
tf
.
test
.
TestCase
):
class
ConditionalEvalTest
(
absl
test
.
TestCase
):
def
test_build_graph
(
self
):
conditional_eval
.
main
(
None
,
run_eval_loop
=
False
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
absl
test
.
main
()
research/gan/mnist/data_provider_test.py
View file @
4f7074f6
...
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
os
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
...
...
@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase):
def
test_mnist_data_reading
(
self
):
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/mnist/testdata'
)
batch_size
=
5
...
...
research/gan/mnist/eval.py
View file @
4f7074f6
...
...
@@ -20,13 +20,14 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
data_provider
import
networks
import
util
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
tfgan
=
tf
.
contrib
.
gan
...
...
@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/mnist/eval_test.py
View file @
4f7074f6
...
...
@@ -20,21 +20,21 @@ from __future__ import print_function
import
tensorflow
as
tf
from
absl
import
flags
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
import
eval
# pylint:disable=redefined-builtin
class
EvalTest
(
tf
.
test
.
TestCase
):
class
EvalTest
(
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
eval_real_images
):
tf
.
flags
.
FLAGS
.
eval_real_images
=
eval_real_images
@
parameterized
.
named_parameters
(
(
'RealData'
,
True
),
(
'GeneratedData'
,
False
))
def
test_build_graph
(
self
,
eval_real_images
):
flags
.
FLAGS
.
eval_real_images
=
eval_real_images
eval
.
main
(
None
,
run_eval_loop
=
False
)
def
test_build_graph_realdata
(
self
):
self
.
_test_build_graph_helper
(
True
)
def
test_build_graph_generateddata
(
self
):
self
.
_test_build_graph_helper
(
False
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
absl
test
.
main
()
research/gan/mnist/infogan_eval.py
View file @
4f7074f6
...
...
@@ -26,6 +26,8 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -34,7 +36,6 @@ import data_provider
import
networks
import
util
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
...
...
@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment