Commit df15a276 authored by Taylor Robie's avatar Taylor Robie Committed by A. Unique TensorFlower
Browse files

Use public endpoints for trainable in BatchNormRelu.

PiperOrigin-RevId: 291081423
parent 5ed215b2
......@@ -48,8 +48,8 @@ class BatchNormRelu(tf.keras.layers.Layer):
fused: `bool` fused option in batch normalization.
name: `str` name for the operation.
"""
super(BatchNormRelu, self).__init__(trainable=trainable)
self._use_relu = relu
self._trainable = trainable
if init_zero:
gamma_initializer = tf.keras.initializers.Zeros()
else:
......@@ -76,7 +76,7 @@ class BatchNormRelu(tf.keras.layers.Layer):
"""
# We will need to keep training=None by default, so that it can be inherited
# from keras.Model.training
if is_training and self._trainable:
if is_training and self.trainable:
is_training = True
inputs = self._batch_norm_op(inputs, training=is_training)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment