Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4f7074f6
Commit
4f7074f6
authored
Jun 05, 2018
by
Joel Shor
Committed by
Joel Shor
Jun 05, 2018
Browse files
Project import generated by Copybara.
PiperOrigin-RevId: 199251174
parent
30cf3752
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
94 additions
and
94 deletions
+94
-94
research/gan/cifar/data_provider_test.py
research/gan/cifar/data_provider_test.py
+2
-1
research/gan/cifar/eval.py
research/gan/cifar/eval.py
+4
-3
research/gan/cifar/eval_test.py
research/gan/cifar/eval_test.py
+9
-12
research/gan/cifar/train.py
research/gan/cifar/train.py
+3
-2
research/gan/cifar/train_test.py
research/gan/cifar/train_test.py
+9
-11
research/gan/cyclegan/data_provider_test.py
research/gan/cyclegan/data_provider_test.py
+9
-7
research/gan/cyclegan/inference_demo.py
research/gan/cyclegan/inference_demo.py
+3
-2
research/gan/cyclegan/inference_demo_test.py
research/gan/cyclegan/inference_demo_test.py
+5
-3
research/gan/cyclegan/train.py
research/gan/cyclegan/train.py
+2
-6
research/gan/cyclegan/train_test.py
research/gan/cyclegan/train_test.py
+2
-4
research/gan/image_compression/data_provider_test.py
research/gan/image_compression/data_provider_test.py
+8
-9
research/gan/image_compression/eval.py
research/gan/image_compression/eval.py
+3
-2
research/gan/image_compression/train.py
research/gan/image_compression/train.py
+3
-3
research/gan/image_compression/train_test.py
research/gan/image_compression/train_test.py
+8
-9
research/gan/mnist/conditional_eval.py
research/gan/mnist/conditional_eval.py
+3
-2
research/gan/mnist/conditional_eval_test.py
research/gan/mnist/conditional_eval_test.py
+3
-3
research/gan/mnist/data_provider_test.py
research/gan/mnist/data_provider_test.py
+2
-1
research/gan/mnist/eval.py
research/gan/mnist/eval.py
+3
-2
research/gan/mnist/eval_test.py
research/gan/mnist/eval_test.py
+10
-10
research/gan/mnist/infogan_eval.py
research/gan/mnist/infogan_eval.py
+3
-2
No files found.
research/gan/cifar/data_provider_test.py
View file @
4f7074f6
...
@@ -20,6 +20,7 @@ from __future__ import print_function
...
@@ -20,6 +20,7 @@ from __future__ import print_function
import
os
import
os
from
absl
import
flags
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase):
...
@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase):
def
test_cifar10_train_set
(
self
):
def
test_cifar10_train_set
(
self
):
dataset_dir
=
os
.
path
.
join
(
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cifar/testdata'
)
'google3/third_party/tensorflow_models/gan/cifar/testdata'
)
batch_size
=
4
batch_size
=
4
...
...
research/gan/cifar/eval.py
View file @
4f7074f6
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
...
@@ -25,8 +27,7 @@ import networks
...
@@ -25,8 +27,7 @@ import networks
import
util
import
util
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
FLAGS
=
tf
.
flags
.
FLAGS
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
...
@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
...
@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/cifar/eval_test.py
View file @
4f7074f6
...
@@ -19,16 +19,22 @@ from __future__ import division
...
@@ -19,16 +19,22 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
import
eval
# pylint:disable=redefined-builtin
import
eval
# pylint:disable=redefined-builtin
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
mock
=
tf
.
test
.
mock
class
EvalTest
(
tf
.
test
.
TestCase
):
class
EvalTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
eval_real_images
,
conditional_eval
):
@
parameterized
.
named_parameters
(
(
'RealData'
,
True
,
False
),
(
'GeneratedData'
,
False
,
False
),
(
'GeneratedDataConditional'
,
False
,
True
))
def
test_build_graph
(
self
,
eval_real_images
,
conditional_eval
):
FLAGS
.
eval_real_images
=
eval_real_images
FLAGS
.
eval_real_images
=
eval_real_images
FLAGS
.
conditional_eval
=
conditional_eval
FLAGS
.
conditional_eval
=
conditional_eval
# Mock `frechet_inception_distance` and `inception_score`, which are
# Mock `frechet_inception_distance` and `inception_score`, which are
...
@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase):
...
@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase):
mock_iscore
.
return_value
=
1.0
mock_iscore
.
return_value
=
1.0
eval
.
main
(
None
,
run_eval_loop
=
False
)
eval
.
main
(
None
,
run_eval_loop
=
False
)
def
test_build_graph_realdata
(
self
):
self
.
_test_build_graph_helper
(
True
,
False
)
def
test_build_graph_generateddata
(
self
):
self
.
_test_build_graph_helper
(
False
,
False
)
def
test_build_graph_generateddataconditional
(
self
):
self
.
_test_build_graph_helper
(
False
,
True
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
research/gan/cifar/train.py
View file @
4f7074f6
...
@@ -18,6 +18,8 @@ from __future__ import division
...
@@ -18,6 +18,8 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
...
@@ -25,7 +27,6 @@ import networks
...
@@ -25,7 +27,6 @@ import networks
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
flags
=
tf
.
flags
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
...
@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas):
...
@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
tf
.
app
.
run
()
tf
.
app
.
run
()
research/gan/cifar/train_test.py
View file @
4f7074f6
...
@@ -19,17 +19,23 @@ from __future__ import division
...
@@ -19,17 +19,23 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
train
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
mock
=
tf
.
test
.
mock
class
TrainTest
(
tf
.
test
.
TestCase
):
class
TrainTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
conditional
,
use_sync_replicas
):
@
parameterized
.
named_parameters
(
(
'Unconditional'
,
False
,
False
),
(
'Conditional'
,
True
,
False
),
(
'SyncReplicas'
,
False
,
True
))
def
test_build_graph_helper
(
self
,
conditional
,
use_sync_replicas
):
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
conditional
=
conditional
FLAGS
.
conditional
=
conditional
FLAGS
.
use_sync_replicas
=
use_sync_replicas
FLAGS
.
use_sync_replicas
=
use_sync_replicas
...
@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase):
...
@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase):
mock_imgs
,
mock_lbls
,
None
,
None
)
mock_imgs
,
mock_lbls
,
None
,
None
)
train
.
main
(
None
)
train
.
main
(
None
)
def
test_build_graph_unconditional
(
self
):
self
.
_test_build_graph_helper
(
False
,
False
)
def
test_build_graph_conditional
(
self
):
self
.
_test_build_graph_helper
(
True
,
False
)
def
test_build_graph_syncreplicas
(
self
):
self
.
_test_build_graph_helper
(
False
,
True
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
research/gan/cyclegan/data_provider_test.py
View file @
4f7074f6
...
@@ -20,6 +20,7 @@ from __future__ import print_function
...
@@ -20,6 +20,7 @@ from __future__ import print_function
import
os
import
os
from
absl
import
flags
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -31,6 +32,12 @@ mock = tf.test.mock
...
@@ -31,6 +32,12 @@ mock = tf.test.mock
class
DataProviderTest
(
tf
.
test
.
TestCase
):
class
DataProviderTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
DataProviderTest
,
self
).
setUp
()
self
.
testdata_dir
=
os
.
path
.
join
(
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata'
)
def
test_normalize_image
(
self
):
def
test_normalize_image
(
self
):
image
=
tf
.
random_uniform
(
shape
=
(
8
,
8
,
3
),
maxval
=
256
,
dtype
=
tf
.
int32
)
image
=
tf
.
random_uniform
(
shape
=
(
8
,
8
,
3
),
maxval
=
256
,
dtype
=
tf
.
int32
)
rescaled_image
=
data_provider
.
normalize_image
(
image
)
rescaled_image
=
data_provider
.
normalize_image
(
image
)
...
@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase):
...
@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase):
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch2
).
shape
)
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch2
).
shape
)
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch3
).
shape
)
self
.
assertTupleEqual
((
10
,
10
,
3
),
sess
.
run
(
patch3
).
shape
)
def
_get_testdata_dir
(
self
):
return
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata'
)
def
test_custom_dataset_provider
(
self
):
def
test_custom_dataset_provider
(
self
):
file_pattern
=
os
.
path
.
join
(
self
.
_get_
testdata_dir
()
,
'*.jpg'
)
file_pattern
=
os
.
path
.
join
(
self
.
testdata_dir
,
'*.jpg'
)
batch_size
=
3
batch_size
=
3
patch_size
=
8
patch_size
=
8
images
=
data_provider
.
_provide_custom_dataset
(
images
=
data_provider
.
_provide_custom_dataset
(
...
@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase):
...
@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase):
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
def
test_custom_datasets_provider
(
self
):
def
test_custom_datasets_provider
(
self
):
file_pattern
=
os
.
path
.
join
(
self
.
_get_
testdata_dir
()
,
'*.jpg'
)
file_pattern
=
os
.
path
.
join
(
self
.
testdata_dir
,
'*.jpg'
)
batch_size
=
3
batch_size
=
3
patch_size
=
8
patch_size
=
8
images_list
=
data_provider
.
provide_custom_datasets
(
images_list
=
data_provider
.
provide_custom_datasets
(
...
...
research/gan/cyclegan/inference_demo.py
View file @
4f7074f6
...
@@ -21,6 +21,8 @@ from __future__ import print_function
...
@@ -21,6 +21,8 @@ from __future__ import print_function
import
os
import
os
from
absl
import
app
from
absl
import
flags
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -28,7 +30,6 @@ import tensorflow as tf
...
@@ -28,7 +30,6 @@ import tensorflow as tf
import
data_provider
import
data_provider
import
networks
import
networks
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
flags
.
DEFINE_string
(
'checkpoint_path'
,
''
,
flags
.
DEFINE_string
(
'checkpoint_path'
,
''
,
...
@@ -147,4 +148,4 @@ def main(_):
...
@@ -147,4 +148,4 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
()
research/gan/cyclegan/inference_demo_test.py
View file @
4f7074f6
...
@@ -5,8 +5,10 @@ from __future__ import division
...
@@ -5,8 +5,10 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
os
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
tensorflow
as
tf
import
tensorflow
as
tf
import
inference_demo
import
inference_demo
...
@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase):
...
@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase):
# Create inference graph
# Create inference graph
tf
.
reset_default_graph
()
tf
.
reset_default_graph
()
FLAGS
.
patch_dim
=
FLAGS
.
patch_size
FLAGS
.
patch_dim
=
FLAGS
.
patch_size
tf
.
logging
.
info
(
'dir_path:
{}'
.
format
(
os
.
listdir
(
self
.
_export_dir
))
)
logging
.
info
(
'dir_path:
%s'
,
os
.
listdir
(
self
.
_export_dir
))
FLAGS
.
checkpoint_path
=
self
.
_ckpt_path
FLAGS
.
checkpoint_path
=
self
.
_ckpt_path
FLAGS
.
image_set_x_glob
=
self
.
_image_glob
FLAGS
.
image_set_x_glob
=
self
.
_image_glob
FLAGS
.
image_set_y_glob
=
self
.
_image_glob
FLAGS
.
image_set_y_glob
=
self
.
_image_glob
...
@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase):
...
@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase):
FLAGS
.
generated_y_dir
=
self
.
_geny_dir
FLAGS
.
generated_y_dir
=
self
.
_geny_dir
inference_demo
.
main
(
None
)
inference_demo
.
main
(
None
)
tf
.
logging
.
info
(
'gen x:
{}'
.
format
(
os
.
listdir
(
self
.
_genx_dir
))
)
logging
.
info
(
'gen x:
%s'
,
os
.
listdir
(
self
.
_genx_dir
))
# Check that the image names match
# Check that the image names match
self
.
assertSetEqual
(
self
.
assertSetEqual
(
...
@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase):
...
@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase):
self
.
assertRealisticImage
(
image_path
)
self
.
assertRealisticImage
(
image_path
)
def
assertRealisticImage
(
self
,
image_path
):
def
assertRealisticImage
(
self
,
image_path
):
tf
.
logging
.
info
(
'Testing
{}
for realism.'
.
format
(
image_path
)
)
logging
.
info
(
'Testing
%s
for realism.'
,
image_path
)
# If the normalization is off or forgotten, then the generated image is
# If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved.
# all one pixel value. This tests that different pixel values are achieved.
input_np
=
np
.
asarray
(
PIL
.
Image
.
open
(
image_path
))
input_np
=
np
.
asarray
(
PIL
.
Image
.
open
(
image_path
))
...
...
research/gan/cyclegan/train.py
View file @
4f7074f6
...
@@ -19,13 +19,12 @@ from __future__ import division
...
@@ -19,13 +19,12 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
numpy
as
np
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
import
networks
import
networks
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
@@ -87,10 +86,7 @@ def _define_model(images_x, images_y):
...
@@ -87,10 +86,7 @@ def _define_model(images_x, images_y):
data_y
=
images_y
)
data_y
=
images_y
)
# Add summaries for generated images.
# Add summaries for generated images.
tfgan
.
eval
.
add_image_comparison_summaries
(
tfgan
.
eval
.
add_cyclegan_image_summaries
(
cyclegan_model
)
cyclegan_model
,
num_comparisons
=
3
,
display_diffs
=
False
)
tfgan
.
eval
.
add_gan_model_image_summaries
(
cyclegan_model
,
grid_size
=
int
(
np
.
sqrt
(
FLAGS
.
batch_size
)))
return
cyclegan_model
return
cyclegan_model
...
...
research/gan/cyclegan/train_test.py
View file @
4f7074f6
...
@@ -19,12 +19,13 @@ from __future__ import division
...
@@ -19,12 +19,13 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
train
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
mock
=
tf
.
test
.
mock
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase):
...
@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase):
self
.
assertShapeEqual
(
images_x_np
,
cyclegan_model
.
reconstructed_x
)
self
.
assertShapeEqual
(
images_x_np
,
cyclegan_model
.
reconstructed_x
)
self
.
assertShapeEqual
(
images_y_np
,
cyclegan_model
.
reconstructed_y
)
self
.
assertShapeEqual
(
images_y_np
,
cyclegan_model
.
reconstructed_y
)
mock_eval
.
add_image_comparison_summaries
.
assert_called_once
()
mock_eval
.
add_gan_model_image_summaries
.
assert_called_once
()
@
mock
.
patch
.
object
(
train
.
networks
,
'generator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
train
.
networks
,
'generator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
train
.
networks
,
'discriminator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
train
.
networks
,
'discriminator'
,
autospec
=
True
)
@
mock
.
patch
.
object
(
@
mock
.
patch
.
object
(
...
...
research/gan/image_compression/data_provider_test.py
View file @
4f7074f6
...
@@ -20,6 +20,8 @@ from __future__ import print_function
...
@@ -20,6 +20,8 @@ from __future__ import print_function
import
os
import
os
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -27,11 +29,14 @@ import tensorflow as tf
...
@@ -27,11 +29,14 @@ import tensorflow as tf
import
data_provider
import
data_provider
class
DataProviderTest
(
tf
.
test
.
TestCase
):
class
DataProviderTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_data_provider_helper
(
self
,
split_name
):
@
parameterized
.
named_parameters
(
(
'train'
,
'train'
),
(
'validation'
,
'validation'
))
def
test_data_provider
(
self
,
split_name
):
dataset_dir
=
os
.
path
.
join
(
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/'
)
'google3/third_party/tensorflow_models/gan/image_compression/testdata/'
)
batch_size
=
3
batch_size
=
3
...
@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase):
...
@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase):
# Check range.
# Check range.
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
self
.
assertTrue
(
np
.
all
(
np
.
abs
(
images_out
)
<=
1.0
))
def
test_data_provider_train
(
self
):
self
.
_test_data_provider_helper
(
'train'
)
def
test_data_provider_validation
(
self
):
self
.
_test_data_provider_helper
(
'validation'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
research/gan/image_compression/eval.py
View file @
4f7074f6
...
@@ -20,13 +20,14 @@ from __future__ import print_function
...
@@ -20,13 +20,14 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
import
networks
import
networks
import
summaries
import
summaries
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
flags
.
DEFINE_string
(
'master'
,
''
,
'Name of the TensorFlow master to use.'
)
...
@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True):
...
@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
()
research/gan/image_compression/train.py
View file @
4f7074f6
...
@@ -20,7 +20,8 @@ from __future__ import division
...
@@ -20,7 +20,8 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
...
@@ -29,7 +30,6 @@ import summaries
...
@@ -29,7 +30,6 @@ import summaries
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'The number of images in each batch.'
)
...
@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data,
...
@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data,
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
tf
.
app
.
run
()
tf
.
app
.
run
()
research/gan/image_compression/train_test.py
View file @
4f7074f6
...
@@ -19,17 +19,22 @@ from __future__ import division
...
@@ -19,17 +19,22 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
train
import
train
FLAGS
=
tf
.
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
mock
=
tf
.
test
.
mock
mock
=
tf
.
test
.
mock
class
TrainTest
(
tf
.
test
.
TestCase
):
class
TrainTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
weight_factor
):
@
parameterized
.
named_parameters
(
(
'NoAdversarialLoss'
,
0.0
),
(
'AdversarialLoss'
,
1.0
))
def
test_build_graph
(
self
,
weight_factor
):
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
max_number_of_steps
=
0
FLAGS
.
weight_factor
=
weight_factor
FLAGS
.
weight_factor
=
weight_factor
...
@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase):
...
@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase):
mock_data_provider
.
provide_data
.
return_value
=
mock_imgs
mock_data_provider
.
provide_data
.
return_value
=
mock_imgs
train
.
main
(
None
)
train
.
main
(
None
)
def
test_build_graph_noadversarialloss
(
self
):
self
.
_test_build_graph_helper
(
0.0
)
def
test_build_graph_adversarialloss
(
self
):
self
.
_test_build_graph_helper
(
1.0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
...
...
research/gan/mnist/conditional_eval.py
View file @
4f7074f6
...
@@ -19,13 +19,14 @@ from __future__ import division
...
@@ -19,13 +19,14 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
import
networks
import
networks
import
util
import
util
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
...
@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/mnist/conditional_eval_test.py
View file @
4f7074f6
...
@@ -18,15 +18,15 @@ from __future__ import absolute_import
...
@@ -18,15 +18,15 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
tensorflow
as
tf
from
absl.testing
import
absltest
import
conditional_eval
import
conditional_eval
class
ConditionalEvalTest
(
tf
.
test
.
TestCase
):
class
ConditionalEvalTest
(
absl
test
.
TestCase
):
def
test_build_graph
(
self
):
def
test_build_graph
(
self
):
conditional_eval
.
main
(
None
,
run_eval_loop
=
False
)
conditional_eval
.
main
(
None
,
run_eval_loop
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
absl
test
.
main
()
research/gan/mnist/data_provider_test.py
View file @
4f7074f6
...
@@ -21,6 +21,7 @@ from __future__ import print_function
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
os
import
os
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
...
@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase):
...
@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase):
def
test_mnist_data_reading
(
self
):
def
test_mnist_data_reading
(
self
):
dataset_dir
=
os
.
path
.
join
(
dataset_dir
=
os
.
path
.
join
(
tf
.
flags
.
FLAGS
.
test_srcdir
,
flags
.
FLAGS
.
test_srcdir
,
'google3/third_party/tensorflow_models/gan/mnist/testdata'
)
'google3/third_party/tensorflow_models/gan/mnist/testdata'
)
batch_size
=
5
batch_size
=
5
...
...
research/gan/mnist/eval.py
View file @
4f7074f6
...
@@ -20,13 +20,14 @@ from __future__ import print_function
...
@@ -20,13 +20,14 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
import
data_provider
import
data_provider
import
networks
import
networks
import
util
import
util
flags
=
tf
.
flags
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True):
...
@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
research/gan/mnist/eval_test.py
View file @
4f7074f6
...
@@ -20,21 +20,21 @@ from __future__ import print_function
...
@@ -20,21 +20,21 @@ from __future__ import print_function
import
tensorflow
as
tf
from
absl
import
flags
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
import
eval
# pylint:disable=redefined-builtin
import
eval
# pylint:disable=redefined-builtin
class
EvalTest
(
tf
.
test
.
TestCase
):
class
EvalTest
(
parameterized
.
TestCase
):
def
_test_build_graph_helper
(
self
,
eval_real_images
):
@
parameterized
.
named_parameters
(
tf
.
flags
.
FLAGS
.
eval_real_images
=
eval_real_images
(
'RealData'
,
True
),
(
'GeneratedData'
,
False
))
def
test_build_graph
(
self
,
eval_real_images
):
flags
.
FLAGS
.
eval_real_images
=
eval_real_images
eval
.
main
(
None
,
run_eval_loop
=
False
)
eval
.
main
(
None
,
run_eval_loop
=
False
)
def
test_build_graph_realdata
(
self
):
self
.
_test_build_graph_helper
(
True
)
def
test_build_graph_generateddata
(
self
):
self
.
_test_build_graph_helper
(
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
absl
test
.
main
()
research/gan/mnist/infogan_eval.py
View file @
4f7074f6
...
@@ -26,6 +26,8 @@ from __future__ import division
...
@@ -26,6 +26,8 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -34,7 +36,6 @@ import data_provider
...
@@ -34,7 +36,6 @@ import data_provider
import
networks
import
networks
import
util
import
util
flags
=
tf
.
flags
tfgan
=
tf
.
contrib
.
gan
tfgan
=
tf
.
contrib
.
gan
...
@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images):
...
@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
app
.
run
(
main
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment