Commit 1c5dca7f authored by xinliupitt

get_config and doc

parent 3e0fa932
@@ -49,9 +49,11 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
-use_bias: Whether to enable use_bias in attention layer.
+use_bias: Whether to enable use_bias in attention layer. If set to False,
+  the attention layer's bias term is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
-  layers.
+  layers. If set to False, the output of the attention and intermediate
+  dense layers is normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
@@ -200,7 +202,13 @@ class Transformer(tf.keras.layers.Layer):
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
-tf.keras.constraints.serialize(self._bias_constraint)
+tf.keras.constraints.serialize(self._bias_constraint),
+"use_bias":
+self._use_bias,
+"norm_first":
+self._norm_first,
+"norm_epsilon":
+self._norm_epsilon
}
base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
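The new entries mean use_bias, norm_first, and norm_epsilon now survive a get_config/from_config round trip; previously they were absent from the dict, so a restored layer silently fell back to the defaults. A minimal round-trip sketch (the import path and constructor arguments are assumptions, not taken from this commit):

from official.nlp.modeling.layers import transformer  # assumed import path

layer = transformer.Transformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6)

config = layer.get_config()
assert config["use_bias"] is False   # now preserved
assert config["norm_first"] is True  # now preserved

# Rebuild an equivalent layer from the serialized config.
restored = transformer.Transformer.from_config(config)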
@@ -281,9 +289,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
-use_bias: Whether to enable use_bias in attention layer.
+use_bias: Whether to enable use_bias in attention layer. If set to False,
+  the attention layer's bias term is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
-  layers.
+  layers. If set to False, the output of the attention and intermediate
+  dense layers is normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
@@ -404,6 +414,47 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape)
+
+def get_config(self):
+config = {
+"num_attention_heads":
+self.num_attention_heads,
+"intermediate_size":
+self.intermediate_size,
+"intermediate_activation":
+self.intermediate_activation,
+"dropout_rate":
+self.dropout_rate,
+"attention_dropout_rate":
+self.attention_dropout_rate,
+"multi_channel_cross_attention":
+self.multi_channel_cross_attention,
+"kernel_initializer":
+tf.keras.initializers.serialize(self._kernel_initializer),
+"bias_initializer":
+tf.keras.initializers.serialize(self._bias_initializer),
+"kernel_regularizer":
+tf.keras.regularizers.serialize(self._kernel_regularizer),
+"bias_regularizer":
+tf.keras.regularizers.serialize(self._bias_regularizer),
+"activity_regularizer":
+tf.keras.regularizers.serialize(self._activity_regularizer),
+"kernel_constraint":
+tf.keras.constraints.serialize(self._kernel_constraint),
+"bias_constraint":
+tf.keras.constraints.serialize(self._bias_constraint),
+"use_bias":
+self._use_bias,
+"norm_first":
+self._norm_first,
+"norm_epsilon":
+self._norm_epsilon,
+"cross_attention_cls":
+self._cross_attention_cls
+}
+base_config = super(TransformerDecoderLayer, self).get_config()
+return dict(list(base_config.items()) + list(config.items()))
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
......
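One caveat on the decoder config above, offered as an observation rather than documented behavior: cross_attention_cls is stored as a Python class object, so while the returned dict supports in-process cloning, it is not directly JSON-serializable, which matters for JSON-based model saving. A minimal sketch, assuming layer is an already constructed TransformerDecoderLayer:

import json

config = layer.get_config()  # `layer`: a constructed TransformerDecoderLayer

# Scalar fields such as "use_bias" and "norm_epsilon" dump cleanly, but
# "cross_attention_cls" holds a class object, which json rejects:
try:
  json.dumps(config)
except TypeError as err:
  print("config is not JSON-serializable:", err)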