"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "dff0f0c1143876c374f2f6887fb7a59281344a8f"
Unverified Commit 2d5e95a3 authored by Joel Shor's avatar Joel Shor Committed by GitHub
Browse files

Merge pull request #4181 from joel-shor/master

Made `generate_cifar10_tfrecords.py` python3 compatible. Fixes #3209. Fixes #3428
parents a1adc50b 4f7074f6
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase): ...@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase):
def test_cifar10_train_set(self): def test_cifar10_train_set(self):
dataset_dir = os.path.join( dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir, flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cifar/testdata') 'google3/third_party/tensorflow_models/gan/cifar/testdata')
batch_size = 4 batch_size = 4
......
...@@ -18,6 +18,8 @@ from __future__ import absolute_import ...@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
...@@ -25,8 +27,7 @@ import networks ...@@ -25,8 +27,7 @@ import networks
import util import util
flags = tf.flags FLAGS = flags.FLAGS
FLAGS = tf.flags.FLAGS
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.') flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
...@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes): ...@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -19,16 +19,22 @@ from __future__ import division ...@@ -19,16 +19,22 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
import eval # pylint:disable=redefined-builtin import eval # pylint:disable=redefined-builtin
FLAGS = tf.flags.FLAGS FLAGS = flags.FLAGS
mock = tf.test.mock mock = tf.test.mock
class EvalTest(tf.test.TestCase): class EvalTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, eval_real_images, conditional_eval): @parameterized.named_parameters(
('RealData', True, False),
('GeneratedData', False, False),
('GeneratedDataConditional', False, True))
def test_build_graph(self, eval_real_images, conditional_eval):
FLAGS.eval_real_images = eval_real_images FLAGS.eval_real_images = eval_real_images
FLAGS.conditional_eval = conditional_eval FLAGS.conditional_eval = conditional_eval
# Mock `frechet_inception_distance` and `inception_score`, which are # Mock `frechet_inception_distance` and `inception_score`, which are
...@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase): ...@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase):
mock_iscore.return_value = 1.0 mock_iscore.return_value = 1.0
eval.main(None, run_eval_loop=False) eval.main(None, run_eval_loop=False)
def test_build_graph_realdata(self):
self._test_build_graph_helper(True, False)
def test_build_graph_generateddata(self):
self._test_build_graph_helper(False, False)
def test_build_graph_generateddataconditional(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -18,6 +18,8 @@ from __future__ import division ...@@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
...@@ -25,7 +27,6 @@ import networks ...@@ -25,7 +27,6 @@ import networks
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
flags = tf.flags
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.') flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
...@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas): ...@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) logging.set_verbosity(logging.INFO)
tf.app.run() tf.app.run()
...@@ -19,17 +19,23 @@ from __future__ import division ...@@ -19,17 +19,23 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import train import train
FLAGS = tf.flags.FLAGS FLAGS = flags.FLAGS
mock = tf.test.mock mock = tf.test.mock
class TrainTest(tf.test.TestCase): class TrainTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, conditional, use_sync_replicas): @parameterized.named_parameters(
('Unconditional', False, False),
('Conditional', True, False),
('SyncReplicas', False, True))
def test_build_graph_helper(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0 FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas FLAGS.use_sync_replicas = use_sync_replicas
...@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase): ...@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase):
mock_imgs, mock_lbls, None, None) mock_imgs, mock_lbls, None, None)
train.main(None) train.main(None)
def test_build_graph_unconditional(self):
self._test_build_graph_helper(False, False)
def test_build_graph_conditional(self):
self._test_build_graph_helper(True, False)
def test_build_graph_syncreplicas(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -31,6 +32,12 @@ mock = tf.test.mock ...@@ -31,6 +32,12 @@ mock = tf.test.mock
class DataProviderTest(tf.test.TestCase): class DataProviderTest(tf.test.TestCase):
def setUp(self):
super(DataProviderTest, self).setUp()
self.testdata_dir = os.path.join(
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_normalize_image(self): def test_normalize_image(self):
image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32) image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
rescaled_image = data_provider.normalize_image(image) rescaled_image = data_provider.normalize_image(image)
...@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase): ...@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase):
self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape) self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape) self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)
def _get_testdata_dir(self):
return os.path.join(
tf.flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_custom_dataset_provider(self): def test_custom_dataset_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg') file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3 batch_size = 3
patch_size = 8 patch_size = 8
images = data_provider._provide_custom_dataset( images = data_provider._provide_custom_dataset(
...@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase): ...@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase):
self.assertTrue(np.all(np.abs(images_out) <= 1.0)) self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self): def test_custom_datasets_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg') file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3 batch_size = 3
patch_size = 8 patch_size = 8
images_list = data_provider.provide_custom_datasets( images_list = data_provider.provide_custom_datasets(
......
...@@ -21,6 +21,8 @@ from __future__ import print_function ...@@ -21,6 +21,8 @@ from __future__ import print_function
import os import os
from absl import app
from absl import flags
import numpy as np import numpy as np
import PIL import PIL
import tensorflow as tf import tensorflow as tf
...@@ -28,7 +30,6 @@ import tensorflow as tf ...@@ -28,7 +30,6 @@ import tensorflow as tf
import data_provider import data_provider
import networks import networks
flags = tf.flags
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
flags.DEFINE_string('checkpoint_path', '', flags.DEFINE_string('checkpoint_path', '',
...@@ -147,4 +148,4 @@ def main(_): ...@@ -147,4 +148,4 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run()
...@@ -5,8 +5,10 @@ from __future__ import division ...@@ -5,8 +5,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from absl import logging
import numpy as np import numpy as np
import PIL import PIL
import tensorflow as tf import tensorflow as tf
import inference_demo import inference_demo
...@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase): ...@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase):
# Create inference graph # Create inference graph
tf.reset_default_graph() tf.reset_default_graph()
FLAGS.patch_dim = FLAGS.patch_size FLAGS.patch_dim = FLAGS.patch_size
tf.logging.info('dir_path: {}'.format(os.listdir(self._export_dir))) logging.info('dir_path: %s', os.listdir(self._export_dir))
FLAGS.checkpoint_path = self._ckpt_path FLAGS.checkpoint_path = self._ckpt_path
FLAGS.image_set_x_glob = self._image_glob FLAGS.image_set_x_glob = self._image_glob
FLAGS.image_set_y_glob = self._image_glob FLAGS.image_set_y_glob = self._image_glob
...@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase): ...@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase):
FLAGS.generated_y_dir = self._geny_dir FLAGS.generated_y_dir = self._geny_dir
inference_demo.main(None) inference_demo.main(None)
tf.logging.info('gen x: {}'.format(os.listdir(self._genx_dir))) logging.info('gen x: %s', os.listdir(self._genx_dir))
# Check that the image names match # Check that the image names match
self.assertSetEqual( self.assertSetEqual(
...@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase): ...@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase):
self.assertRealisticImage(image_path) self.assertRealisticImage(image_path)
def assertRealisticImage(self, image_path): def assertRealisticImage(self, image_path):
tf.logging.info('Testing {} for realism.'.format(image_path)) logging.info('Testing %s for realism.', image_path)
# If the normalization is off or forgotten, then the generated image is # If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved. # all one pixel value. This tests that different pixel values are achieved.
input_np = np.asarray(PIL.Image.open(image_path)) input_np = np.asarray(PIL.Image.open(image_path))
......
...@@ -19,13 +19,12 @@ from __future__ import division ...@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
import networks import networks
flags = tf.flags
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
...@@ -87,10 +86,7 @@ def _define_model(images_x, images_y): ...@@ -87,10 +86,7 @@ def _define_model(images_x, images_y):
data_y=images_y) data_y=images_y)
# Add summaries for generated images. # Add summaries for generated images.
tfgan.eval.add_image_comparison_summaries( tfgan.eval.add_cyclegan_image_summaries(cyclegan_model)
cyclegan_model, num_comparisons=3, display_diffs=False)
tfgan.eval.add_gan_model_image_summaries(
cyclegan_model, grid_size=int(np.sqrt(FLAGS.batch_size)))
return cyclegan_model return cyclegan_model
......
...@@ -19,12 +19,13 @@ from __future__ import division ...@@ -19,12 +19,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import train import train
FLAGS = tf.flags.FLAGS FLAGS = flags.FLAGS
mock = tf.test.mock mock = tf.test.mock
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
...@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase): ...@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase):
self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x) self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x)
self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y) self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y)
mock_eval.add_image_comparison_summaries.assert_called_once()
mock_eval.add_gan_model_image_summaries.assert_called_once()
@mock.patch.object(train.networks, 'generator', autospec=True) @mock.patch.object(train.networks, 'generator', autospec=True)
@mock.patch.object(train.networks, 'discriminator', autospec=True) @mock.patch.object(train.networks, 'discriminator', autospec=True)
@mock.patch.object( @mock.patch.object(
......
...@@ -20,6 +20,8 @@ from __future__ import print_function ...@@ -20,6 +20,8 @@ from __future__ import print_function
import os import os
from absl import flags
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -27,11 +29,14 @@ import tensorflow as tf ...@@ -27,11 +29,14 @@ import tensorflow as tf
import data_provider import data_provider
class DataProviderTest(tf.test.TestCase): class DataProviderTest(tf.test.TestCase, parameterized.TestCase):
def _test_data_provider_helper(self, split_name): @parameterized.named_parameters(
('train', 'train'),
('validation', 'validation'))
def test_data_provider(self, split_name):
dataset_dir = os.path.join( dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir, flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/') 'google3/third_party/tensorflow_models/gan/image_compression/testdata/')
batch_size = 3 batch_size = 3
...@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase): ...@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase):
# Check range. # Check range.
self.assertTrue(np.all(np.abs(images_out) <= 1.0)) self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_data_provider_train(self):
self._test_data_provider_helper('train')
def test_data_provider_validation(self):
self._test_data_provider_helper('validation')
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -20,13 +20,14 @@ from __future__ import print_function ...@@ -20,13 +20,14 @@ from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
import networks import networks
import summaries import summaries
flags = tf.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.') flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
...@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True): ...@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run()
...@@ -20,7 +20,8 @@ from __future__ import division ...@@ -20,7 +20,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
...@@ -29,7 +30,6 @@ import summaries ...@@ -29,7 +30,6 @@ import summaries
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
flags = tf.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.') flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
...@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data, ...@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data,
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) logging.set_verbosity(logging.INFO)
tf.app.run() tf.app.run()
...@@ -19,17 +19,22 @@ from __future__ import division ...@@ -19,17 +19,22 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import train import train
FLAGS = tf.flags.FLAGS FLAGS = flags.FLAGS
mock = tf.test.mock mock = tf.test.mock
class TrainTest(tf.test.TestCase): class TrainTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, weight_factor): @parameterized.named_parameters(
('NoAdversarialLoss', 0.0),
('AdversarialLoss', 1.0))
def test_build_graph(self, weight_factor):
FLAGS.max_number_of_steps = 0 FLAGS.max_number_of_steps = 0
FLAGS.weight_factor = weight_factor FLAGS.weight_factor = weight_factor
...@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase): ...@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase):
mock_data_provider.provide_data.return_value = mock_imgs mock_data_provider.provide_data.return_value = mock_imgs
train.main(None) train.main(None)
def test_build_graph_noadversarialloss(self):
self._test_build_graph_helper(0.0)
def test_build_graph_adversarialloss(self):
self._test_build_graph_helper(1.0)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
......
...@@ -19,13 +19,14 @@ from __future__ import division ...@@ -19,13 +19,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
import networks import networks
import util import util
flags = tf.flags
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
...@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims): ...@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -18,15 +18,15 @@ from __future__ import absolute_import ...@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf from absl.testing import absltest
import conditional_eval import conditional_eval
class ConditionalEvalTest(tf.test.TestCase): class ConditionalEvalTest(absltest.TestCase):
def test_build_graph(self): def test_build_graph(self):
conditional_eval.main(None, run_eval_loop=False) conditional_eval.main(None, run_eval_loop=False)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() absltest.main()
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
import os import os
from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
...@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase): ...@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase):
def test_mnist_data_reading(self): def test_mnist_data_reading(self):
dataset_dir = os.path.join( dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir, flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/mnist/testdata') 'google3/third_party/tensorflow_models/gan/mnist/testdata')
batch_size = 5 batch_size = 5
......
...@@ -20,13 +20,14 @@ from __future__ import print_function ...@@ -20,13 +20,14 @@ from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
import networks import networks
import util import util
flags = tf.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
...@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True): ...@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -20,21 +20,21 @@ from __future__ import print_function ...@@ -20,21 +20,21 @@ from __future__ import print_function
import tensorflow as tf from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import eval # pylint:disable=redefined-builtin import eval # pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase): class EvalTest(parameterized.TestCase):
def _test_build_graph_helper(self, eval_real_images): @parameterized.named_parameters(
tf.flags.FLAGS.eval_real_images = eval_real_images ('RealData', True),
('GeneratedData', False))
def test_build_graph(self, eval_real_images):
flags.FLAGS.eval_real_images = eval_real_images
eval.main(None, run_eval_loop=False) eval.main(None, run_eval_loop=False)
def test_build_graph_realdata(self):
self._test_build_graph_helper(True)
def test_build_graph_generateddata(self):
self._test_build_graph_helper(False)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() absltest.main()
...@@ -26,6 +26,8 @@ from __future__ import division ...@@ -26,6 +26,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import app
from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -34,7 +36,6 @@ import data_provider ...@@ -34,7 +36,6 @@ import data_provider
import networks import networks
import util import util
flags = tf.flags
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
...@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images): ...@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment