Commit 399c7b50 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Remove unused use_normalization variable, and move tf.split into if branch.

PiperOrigin-RevId: 385611054
parent 04c1e39f
......@@ -97,7 +97,6 @@ class ProjectionHead(tf.keras.layers.Layer):
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'use_normalization': self._use_normalization,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
......
......@@ -90,14 +90,15 @@ class SimCLRModel(tf.keras.Model):
if training and self._mode == PRETRAIN:
num_transforms = 2
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list = tf.split(
inputs, num_or_size_splits=num_transforms, axis=-1)
# (num_transforms * bsz, h, w, c)
features = tf.concat(features_list, 0)
else:
num_transforms = 1
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list = tf.split(inputs, num_or_size_splits=num_transforms, axis=-1)
# (num_transforms * bsz, h, w, c)
features = tf.concat(features_list, 0)
features = inputs
# Base network forward pass.
endpoints = self._backbone(features, training=training)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment