Commit 8fb0438d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #10564 from exx8:master

PiperOrigin-RevId: 437890828
parents 428a156b 611abda1
......@@ -173,6 +173,22 @@ class Encoder(tf.keras.layers.Layer):
x = self._norm(x)
return x
def get_config(self):
config = {
'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads,
'dropout_rate': self._dropout_rate,
'attention_dropout_rate': self._attention_dropout_rate,
'kernel_regularizer': self._kernel_regularizer,
'inputs_positions': self._inputs_positions,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed,
}
base_config = super().get_config()
return base_config.update(config)
class VisionTransformer(tf.keras.Model):
"""Class to build VisionTransformer family model."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment