"vscode:/vscode.git/clone" did not exist on "7dcf1b7cc52d8311a0d2efac5e6f77df0ad7b304"
Commit 84577d6d authored by Mikael Souza's avatar Mikael Souza
Browse files

Adding fused_batch_norm parameter

parent 23b5b422
...@@ -46,7 +46,7 @@ def generator(noise, is_training=True): ...@@ -46,7 +46,7 @@ def generator(noise, is_training=True):
Returns: Returns:
A single Tensor with a batch of generated CIFAR images. A single Tensor with a batch of generated CIFAR images.
""" """
images, _ = dcgan.generator(noise, is_training=is_training) images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
...@@ -68,7 +68,7 @@ def conditional_generator(inputs, is_training=True): ...@@ -68,7 +68,7 @@ def conditional_generator(inputs, is_training=True):
noise, one_hot_labels = inputs noise, one_hot_labels = inputs
noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels) noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
images, _ = dcgan.generator(noise, is_training=is_training) images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
...@@ -94,7 +94,7 @@ def discriminator(img, unused_conditioning, is_training=True): ...@@ -94,7 +94,7 @@ def discriminator(img, unused_conditioning, is_training=True):
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, _ = dcgan.discriminator(img, is_training=is_training) logits, _ = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
return logits return logits
...@@ -118,7 +118,7 @@ def conditional_discriminator(img, conditioning, is_training=True): ...@@ -118,7 +118,7 @@ def conditional_discriminator(img, conditioning, is_training=True):
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, end_points = dcgan.discriminator(img, is_training=is_training) logits, end_points = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
# Condition the last convolution layer. # Condition the last convolution layer.
_, one_hot_labels = conditioning _, one_hot_labels = conditioning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment