Commit 611abda1 authored by exx8's avatar exx8
Browse files

add a get_config implementation.

parent 50369291
...@@ -172,6 +172,22 @@ class Encoder(tf.keras.layers.Layer): ...@@ -172,6 +172,22 @@ class Encoder(tf.keras.layers.Layer):
x = encoder_layer(x, training=training) x = encoder_layer(x, training=training)
x = self._norm(x) x = self._norm(x)
return x return x
def get_config(self):
config = {
'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads,
'dropout_rate': self._dropout_rate,
'attention_dropout_rate': self._attention_dropout_rate,
'kernel_regularizer': self._kernel_regularizer,
'inputs_positions': self._inputs_positions,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
class VisionTransformer(tf.keras.Model): class VisionTransformer(tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment