Commit 1c5dca7f authored by xinliupitt's avatar xinliupitt
Browse files

get_config and doc

parent 3e0fa932
...@@ -49,9 +49,11 @@ class Transformer(tf.keras.layers.Layer): ...@@ -49,9 +49,11 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer. use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense norm_first: Whether to normalize inputs to attention and intermediate dense
layers. layers. If set False, output of attention and intermediate dense layers
is normalized.
norm_epsilon: Epsilon value to initialize normalization layers. norm_epsilon: Epsilon value to initialize normalization layers.
""" """
...@@ -200,7 +202,13 @@ class Transformer(tf.keras.layers.Layer): ...@@ -200,7 +202,13 @@ class Transformer(tf.keras.layers.Layer):
"kernel_constraint": "kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint), tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint": "bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint) tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
} }
base_config = super(Transformer, self).get_config() base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -281,9 +289,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -281,9 +289,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer. use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense norm_first: Whether to normalize inputs to attention and intermediate dense
layers. layers. If set False, output of attention and intermediate dense layers
is normalized.
norm_epsilon: Epsilon value to initialize normalization layers. norm_epsilon: Epsilon value to initialize normalization layers.
""" """
...@@ -404,6 +414,47 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -404,6 +414,47 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon) name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape) super(TransformerDecoderLayer, self).build(input_shape)
def get_config(self):
  """Returns the configuration of this layer for serialization.

  The returned dict contains the base `tf.keras.layers.Layer` config merged
  with this layer's constructor arguments, so the layer can be re-created via
  `from_config`. Initializer/regularizer/constraint objects are converted to
  serializable form through the corresponding `tf.keras.*.serialize` helpers.

  Returns:
    A Python dict of the layer's configuration.
  """
  base_config = super(TransformerDecoderLayer, self).get_config()
  layer_config = {
      "num_attention_heads": self.num_attention_heads,
      "intermediate_size": self.intermediate_size,
      "intermediate_activation": self.intermediate_activation,
      "dropout_rate": self.dropout_rate,
      "attention_dropout_rate": self.attention_dropout_rate,
      "multi_channel_cross_attention": self.multi_channel_cross_attention,
      "kernel_initializer": tf.keras.initializers.serialize(
          self._kernel_initializer),
      "bias_initializer": tf.keras.initializers.serialize(
          self._bias_initializer),
      "kernel_regularizer": tf.keras.regularizers.serialize(
          self._kernel_regularizer),
      "bias_regularizer": tf.keras.regularizers.serialize(
          self._bias_regularizer),
      "activity_regularizer": tf.keras.regularizers.serialize(
          self._activity_regularizer),
      "kernel_constraint": tf.keras.constraints.serialize(
          self._kernel_constraint),
      "bias_constraint": tf.keras.constraints.serialize(
          self._bias_constraint),
      "use_bias": self._use_bias,
      "norm_first": self._norm_first,
      "norm_epsilon": self._norm_epsilon,
      # NOTE(review): this stores the raw class object, which is not
      # JSON-serializable — confirm whether saved-model round-trips need it
      # converted to a registered identifier instead.
      "cross_attention_cls": self._cross_attention_cls,
  }
  # Later keys win, so this layer's values override any base-config duplicates
  # (same result as dict(list(base.items()) + list(config.items()))).
  merged_config = dict(base_config)
  merged_config.update(layer_config)
  return merged_config
def common_layers_with_encoder(self): def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block.""" """Gets layer objects that can make a Transformer encoder block."""
return [ return [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment