Commit 4862f9d3 authored by Ellery Wulczyn's avatar Ellery Wulczyn Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 469579875
parent 4aafcbfe
...@@ -228,7 +228,8 @@ class Encoder(tf.keras.layers.Layer): ...@@ -228,7 +228,8 @@ class Encoder(tf.keras.layers.Layer):
return x return x
def get_config(self): def get_config(self):
config = { config = super().get_config()
updates = {
'num_layers': self._num_layers, 'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim, 'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads, 'num_heads': self._num_heads,
...@@ -239,9 +240,11 @@ class Encoder(tf.keras.layers.Layer): ...@@ -239,9 +240,11 @@ class Encoder(tf.keras.layers.Layer):
'init_stochastic_depth_rate': self._init_stochastic_depth_rate, 'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed, 'add_pos_embed': self._add_pos_embed,
'pos_embed_origin_shape': self._pos_embed_origin_shape,
'pos_embed_target_shape': self._pos_embed_target_shape,
} }
base_config = super().get_config() config.update(updates)
return base_config.update(config) return config
class VisionTransformer(tf.keras.Model): class VisionTransformer(tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment