Commit 3de127e1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 373569850
parent e5ed7441
@@ -91,7 +91,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._inner_dim = inner_dim
     self._inner_activation = inner_activation
     self._attention_dropout = attention_dropout
+    self._attention_dropout_rate = attention_dropout
     self._output_dropout = output_dropout
+    self._output_dropout_rate = output_dropout
     self._output_range = output_range
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
@@ -195,9 +197,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         "inner_activation":
             self._inner_activation,
         "output_dropout":
-            self._output_dropout,
+            self._output_dropout_rate,
         "attention_dropout":
-            self._attention_dropout,
+            self._attention_dropout_rate,
         "output_range":
             self._output_range,
         "kernel_initializer":
...
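Note on the hunks above: `TransformerEncoderBlock.get_config` now reads the dropout rates from dedicated `*_rate` attributes rather than from `self._attention_dropout` / `self._output_dropout`. A plausible motivation (an assumption, not stated in this "Internal change" commit) is that the non-`_rate` attributes can later be rebound, e.g. to `tf.keras.layers.Dropout` instances in `build()`, which would make the old `get_config` return a layer object instead of the original float. A minimal self-contained sketch of that pattern, using a hypothetical `DropoutBlock` layer:

import tensorflow as tf


class DropoutBlock(tf.keras.layers.Layer):
  """Toy layer illustrating the *_rate bookkeeping pattern."""

  def __init__(self, attention_dropout=0.1, **kwargs):
    super().__init__(**kwargs)
    # Keep the raw constructor value under a *_rate name so it stays
    # available for serialization no matter what build() does below.
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout

  def build(self, input_shape):
    # Rebinding the non-rate attribute to a layer would otherwise lose
    # the original float rate.
    self._attention_dropout = tf.keras.layers.Dropout(
        rate=self._attention_dropout_rate)
    super().build(input_shape)

  def call(self, inputs, training=None):
    return self._attention_dropout(inputs, training=training)

  def get_config(self):
    config = super().get_config()
    # Serialize the stored rate, not the built Dropout layer.
    config["attention_dropout"] = self._attention_dropout_rate
    return config


layer = DropoutBlock(attention_dropout=0.2)
layer(tf.zeros([1, 4]))  # Triggers build(), rebinding the attribute.
assert layer.get_config()["attention_dropout"] == 0.2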
@@ -98,6 +98,46 @@ class Transformer(keras_nlp.layers.TransformerEncoderBlock):
         attention_initializer=attention_initializer,
         **kwargs)
 
+  def get_config(self):
+    return {
+        "num_attention_heads":
+            self._num_heads,
+        "intermediate_size":
+            self._inner_dim,
+        "intermediate_activation":
+            self._inner_activation,
"dropout_rate":
self._attention_dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"intermediate_dropout":
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer)
}
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable @gin.configurable
@@ -275,8 +315,10 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
           **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1,
-        epsilon=self._norm_epsilon, dtype="float32")
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype="float32")
     super().build(input_shape)
 
   def get_config(self):
...
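Taken together, the second file's change gives the `Transformer` subclass its own `get_config` keyed by its constructor argument names (`intermediate_size`, `dropout_rate`, ...) rather than the parent's (`inner_dim`, `output_dropout`, ...), so `from_config(get_config())` can reconstruct the layer. A hedged round-trip sketch; the `official.nlp.modeling.layers` import path and the exact `Transformer` signature are assumptions based on the class shown in the diff, not confirmed by this commit:

import tensorflow as tf
# Assumed import path for the TF Model Garden layer in this diff.
from official.nlp.modeling import layers

transformer = layers.Transformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)

config = transformer.get_config()
# Keys now match Transformer's own constructor, so this round-trips
# instead of raising on the parent's argument names like "inner_dim".
rebuilt = layers.Transformer.from_config(config)
assert rebuilt.get_config() == config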