Unverified Commit ee209d4d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix PT TF ViTMAE (#16766)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 5da33f87
......@@ -860,7 +860,9 @@ class TFViTMAEDecoder(tf.keras.layers.Layer):
self.decoder_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
self.decoder_pred = tf.keras.layers.Dense(
config.patch_size**2 * config.num_channels, name="decoder_pred"
config.patch_size**2 * config.num_channels,
kernel_initializer=get_initializer(config.initializer_range),
name="decoder_pred",
) # encoder to decoder
self.config = config
self.num_patches = num_patches
......
......@@ -756,7 +756,7 @@ class ViTMAEDecoder(nn.Module):
[ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
)
self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size)
self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.decoder_pred = nn.Linear(
config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
) # encoder to decoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment