Unverified Commit 0783f1cf authored by Joel Shor's avatar Joel Shor Committed by GitHub
Browse files

Merge pull request #5227 from mikaelsouza/adding-fuse-batch-norm-parameter

Added fused_batch_norm parameter
parents 23b5b422 84577d6d
...@@ -46,7 +46,7 @@ def generator(noise, is_training=True): ...@@ -46,7 +46,7 @@ def generator(noise, is_training=True):
Returns: Returns:
A single Tensor with a batch of generated CIFAR images. A single Tensor with a batch of generated CIFAR images.
""" """
images, _ = dcgan.generator(noise, is_training=is_training) images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
...@@ -68,7 +68,7 @@ def conditional_generator(inputs, is_training=True): ...@@ -68,7 +68,7 @@ def conditional_generator(inputs, is_training=True):
noise, one_hot_labels = inputs noise, one_hot_labels = inputs
noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels) noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
images, _ = dcgan.generator(noise, is_training=is_training) images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
...@@ -94,7 +94,7 @@ def discriminator(img, unused_conditioning, is_training=True): ...@@ -94,7 +94,7 @@ def discriminator(img, unused_conditioning, is_training=True):
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, _ = dcgan.discriminator(img, is_training=is_training) logits, _ = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
return logits return logits
...@@ -118,7 +118,7 @@ def conditional_discriminator(img, conditioning, is_training=True): ...@@ -118,7 +118,7 @@ def conditional_discriminator(img, conditioning, is_training=True):
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, end_points = dcgan.discriminator(img, is_training=is_training) logits, end_points = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
# Condition the last convolution layer. # Condition the last convolution layer.
_, one_hot_labels = conditioning _, one_hot_labels = conditioning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment