Commit a6bbd3fa authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Minor readability improvements.

PiperOrigin-RevId: 364219111
parent 1e9cbdce
......@@ -112,8 +112,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
input_tensor_shape = input_shape[0] if (
len(input_shape) == 2) else input_shape
input_tensor_shape = tf.TensorShape(input_tensor_shape)
if len(input_tensor_shape.as_list()) != 3:
raise ValueError(
"TransformerScaffold expects a three-dimensional input of "
......@@ -170,6 +171,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
self._feedforward_block = None
# self._dropout_rate controls dropout rates at two places:
# after attention, and after FFN.
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment