Commit ff110c3b authored by Ellery Wulczyn's avatar Ellery Wulczyn Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 469579875
parent 9661bd57
......@@ -228,7 +228,8 @@ class Encoder(tf.keras.layers.Layer):
return x
def get_config(self):
config = {
config = super().get_config()
updates = {
'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads,
......@@ -239,9 +240,11 @@ class Encoder(tf.keras.layers.Layer):
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed,
'pos_embed_origin_shape': self._pos_embed_origin_shape,
'pos_embed_target_shape': self._pos_embed_target_shape,
}
base_config = super().get_config()
return base_config.update(config)
config.update(updates)
return config
class VisionTransformer(tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment