Unverified commit f0b49015, authored by Alara Dirik, committed by GitHub

🚨 🚨 🚨 Fix ViT parameter initialization (#19341)

This PR rectifies the discrepancy in training performance between the HF and timm ViT implementations.

- Initializes torch and flax ViT dense layer weights with trunc_normal instead of normal (consistent with the TF implementation); see the sketch below
- Initializes cls_token and position_embeddings with trunc_normal
- Updates DeiT copy to reflect the changes
parent 7e7f62bf
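As a rough standalone illustration of the PyTorch side of the change, sampling `nn.Linear` weights from a truncated normal instead of a plain normal (the layer size and `initializer_range` value are illustrative, not taken from this commit):

```python
import torch
from torch import nn

initializer_range = 0.02  # illustrative; ViTConfig defaults to 0.02

layer = nn.Linear(768, 768)

# Old behaviour: plain normal initialization
# nn.init.normal_(layer.weight, mean=0.0, std=initializer_range)

# New behaviour: truncated normal, in line with the timm / TF ViT implementations
nn.init.trunc_normal_(layer.weight, mean=0.0, std=initializer_range)
```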
@@ -402,9 +402,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
@@ -101,7 +101,9 @@ class FlaxViTPatchEmbeddings(nn.Module):
             strides=(patch_size, patch_size),
             padding="VALID",
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(self, pixel_values):
@@ -122,11 +124,17 @@ class FlaxViTEmbeddings(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
+        self.cls_token = self.param(
+            "cls_token",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, 1, self.config.hidden_size),
+        )
         self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = self.param(
-            "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
+            "position_embeddings",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, num_patches + 1, self.config.hidden_size),
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -156,19 +164,25 @@ class FlaxViTSelfAttention(nn.Module):
         self.query = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.key = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.value = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
@@ -214,7 +228,9 @@ class FlaxViTSelfOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -253,7 +269,9 @@ class FlaxViTIntermediate(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.activation = ACT2FN[self.config.hidden_act]
@@ -271,7 +289,9 @@ class FlaxViTOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -394,7 +414,9 @@ class FlaxViTPooler(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
@@ -572,7 +594,9 @@ class FlaxViTForImageClassificationModule(nn.Module):
         self.classifier = nn.Dense(
             self.config.num_labels,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(
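On the Flax side, the kernels, `cls_token`, and `position_embeddings` are now sampled through `jax.nn.initializers.variance_scaling(..., "truncated_normal")` rather than `jax.nn.initializers.normal` or zeros. A minimal standalone sketch of that initializer in use (the feature size and `initializer_range` value are illustrative, not taken from this commit):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

initializer_range = 0.02  # illustrative; ViTConfig defaults to 0.02

# Dense layer whose kernel is drawn from a truncated normal via variance_scaling,
# mirroring the kernel_init used throughout the Flax ViT modules above.
dense = nn.Dense(
    768,
    kernel_init=jax.nn.initializers.variance_scaling(
        initializer_range**2, "fan_in", "truncated_normal"
    ),
)
params = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 768)))  # kernel shape (768, 768)
```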
@@ -69,11 +69,14 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
         num_patches = self.patch_embeddings.num_patches
         self.cls_token = self.add_weight(
-            shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token"
+            shape=(1, 1, self.config.hidden_size),
+            initializer=get_initializer(self.config.initializer_range),
+            trainable=True,
+            name="cls_token",
         )
         self.position_embeddings = self.add_weight(
             shape=(1, num_patches + 1, self.config.hidden_size),
-            initializer="zeros",
+            initializer=get_initializer(self.config.initializer_range),
             trainable=True,
             name="position_embeddings",
         )
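For the TensorFlow embeddings above, `get_initializer` in `transformers` wraps a truncated-normal Keras initializer (`tf.keras.initializers.TruncatedNormal`), so `cls_token` and `position_embeddings` now start from small random values rather than zeros. A rough standalone sketch of the equivalent initialization (shape and stddev are illustrative, not taken from this commit):

```python
import tensorflow as tf

initializer_range = 0.02  # illustrative; ViTConfig defaults to 0.02
hidden_size = 768         # illustrative

# get_initializer(initializer_range) boils down to a truncated-normal Keras initializer
initializer = tf.keras.initializers.TruncatedNormal(stddev=initializer_range)

cls_token = tf.Variable(initializer(shape=(1, 1, hidden_size)), trainable=True, name="cls_token")
```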
@@ -67,11 +67,17 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.cls_token = nn.Parameter(
+            nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
+        )
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        self.position_embeddings = nn.Parameter(
+            nn.init.trunc_normal_(
+                torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
+            )
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -440,9 +446,7 @@ class ViTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
@@ -581,7 +581,6 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):