"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fc8a93507c365580b27612089bca59d18b66e053"
Unverified Commit ecd7de3d authored by Younes Belkada, committed by GitHub

[`Vision`] [Refactor] Initialize weights in the correct place (#20803)

* fix nit

- initialization in `_init_weights`
- fix copies

* add copied from
parent 6b5a8f83
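The pattern behind the diff below: learnable parameters such as `cls_token` and `position_embeddings` are only allocated in `__init__` (with `torch.randn`), and the truncated-normal initialization moves into `_init_weights`, the single method that `PreTrainedModel` applies to every submodule after construction. A minimal, self-contained sketch of that pattern (the `Toy*` classes here are hypothetical stand-ins, not the actual transformers code):

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, hidden_size: int, num_patches: int = 196):
        super().__init__()
        # Only allocate the parameters here; their values are set in _init_weights.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, hidden_size))


class ToyModel(nn.Module):
    def __init__(self, hidden_size: int = 768, initializer_range: float = 0.02):
        super().__init__()
        self.initializer_range = initializer_range
        self.embeddings = ToyEmbeddings(hidden_size)
        # transformers calls self.post_init(), which dispatches _init_weights over
        # all submodules; plain apply() plays that role in this sketch.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights in one central place instead of inside each __init__."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, ToyEmbeddings):
            nn.init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.initializer_range)
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.initializer_range)
```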
@@ -382,7 +382,7 @@ class ASTPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     supports_gradient_checkpointing = True

-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
+    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
@@ -387,7 +387,6 @@ class DeiTEncoder(nn.Module):
         )


-# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
 class DeiTPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -67,21 +67,11 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -461,6 +451,18 @@ class ViTPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )

     def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
         if isinstance(module, ViTEncoder):
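With the two hunks above, `ViTEmbeddings.__init__` only creates the parameters and `ViTPreTrainedModel._init_weights` fills them; `_init_weights` runs via `post_init()` when a model is built from a config, and during `from_pretrained` for weights that are not loaded from the checkpoint. A quick way to sanity-check the new behavior (assuming a transformers install that includes this commit):

```python
from transformers import ViTConfig, ViTModel

config = ViTConfig()
model = ViTModel(config)  # post_init() applies _init_weights to every submodule

# cls_token and position_embeddings should now follow a truncated normal with
# std close to config.initializer_range (0.02 by default), not a plain randn.
print(model.embeddings.cls_token.std())
print(model.embeddings.position_embeddings.std())
```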
@@ -59,24 +59,15 @@ class ViTHybridEmbeddings(nn.Module):
     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """

+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.__init__ with ViT->ViTHybrid
     def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTHybridPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -485,6 +476,18 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTHybridEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )

     def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
         if isinstance(module, ViTHybridEncoder):