Commit cf7d71e0 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 373569850
parent 371cda4e
@@ -91,7 +91,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._inner_dim = inner_dim
     self._inner_activation = inner_activation
     self._attention_dropout = attention_dropout
+    self._attention_dropout_rate = attention_dropout
     self._output_dropout = output_dropout
+    self._output_dropout_rate = output_dropout
     self._output_range = output_range
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
@@ -195,9 +197,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         "inner_activation":
             self._inner_activation,
         "output_dropout":
-            self._output_dropout,
+            self._output_dropout_rate,
         "attention_dropout":
-            self._attention_dropout,
+            self._attention_dropout_rate,
         "output_range":
             self._output_range,
         "kernel_initializer":
......
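Note on the rename above: `build()` in this class reassigns `self._attention_dropout` and `self._output_dropout` to `tf.keras.layers.Dropout` layers, so `get_config()` needs the original floats preserved under the new `*_rate` names. A minimal round-trip sketch (the `keras_nlp` import path is taken from this repo; the constructor values are placeholders, not from the commit):

```python
import tensorflow as tf
from official.nlp import keras_nlp  # import path assumed from this repo

# Placeholder hyperparameters, for illustration only.
block = keras_nlp.layers.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    output_dropout=0.1,
    attention_dropout=0.1)

# Force build(); the dropout attributes become Dropout layers here,
# which is why get_config() must read the preserved *_rate floats.
block(tf.ones((1, 16, 512)))

config = block.get_config()
assert config["output_dropout"] == 0.1
assert config["attention_dropout"] == 0.1

# Round trip: a fresh block can be rebuilt from the plain-float config.
clone = keras_nlp.layers.TransformerEncoderBlock.from_config(config)
```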
@@ -98,6 +98,46 @@ class Transformer(keras_nlp.layers.TransformerEncoderBlock):
         attention_initializer=attention_initializer,
         **kwargs)
 
+  def get_config(self):
+    return {
+        "num_attention_heads":
+            self._num_heads,
+        "intermediate_size":
+            self._inner_dim,
+        "intermediate_activation":
+            self._inner_activation,
+        "dropout_rate":
+            self._output_dropout_rate,
+        "attention_dropout_rate":
+            self._attention_dropout_rate,
+        "output_range":
+            self._output_range,
+        "kernel_initializer":
+            tf.keras.initializers.serialize(self._kernel_initializer),
+        "bias_initializer":
+            tf.keras.initializers.serialize(self._bias_initializer),
+        "kernel_regularizer":
+            tf.keras.regularizers.serialize(self._kernel_regularizer),
+        "bias_regularizer":
+            tf.keras.regularizers.serialize(self._bias_regularizer),
+        "activity_regularizer":
+            tf.keras.regularizers.serialize(self._activity_regularizer),
+        "kernel_constraint":
+            tf.keras.constraints.serialize(self._kernel_constraint),
+        "bias_constraint":
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon,
+        "intermediate_dropout":
+            self._inner_dropout,
+        "attention_initializer":
+            tf.keras.initializers.serialize(self._attention_initializer)
+    }
 
 @tf.keras.utils.register_keras_serializable(package="Text")
 @gin.configurable
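This override matters because `Transformer` remaps constructor names (`intermediate_size` to `inner_dim`, `dropout_rate` to `output_dropout`, and so on) when calling the parent `__init__`; the inherited `get_config()` would emit parent-style keys that `Transformer.__init__` does not accept. A sketch of the round trip this enables, assuming `Transformer` lives in `official.nlp.modeling.layers` and is itself registered for Keras serialization; the hyperparameter values are placeholders:

```python
import tensorflow as tf
from official.nlp.modeling import layers  # assumed location of Transformer

# Placeholder hyperparameters, for illustration only.
layer = layers.Transformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)

# get_config() now emits Transformer-style keys, so deserialize() can
# feed them straight back into Transformer.__init__.
restored = tf.keras.layers.deserialize(tf.keras.layers.serialize(layer))
assert restored.get_config()["dropout_rate"] == 0.1
```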
@@ -275,8 +315,10 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1,
-        epsilon=self._norm_epsilon, dtype="float32")
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype="float32")
     super().build(input_shape)
 
   def get_config(self):
......