Unverified commit f4ef78af, authored by Mathieu Jouffroy, committed by GitHub
Browse files

using trunc_normal for weight init & cls_token (#19486)

parent 5760a8fc
@@ -451,7 +451,11 @@ class CvtStage(nn.Module):
         self.config = config
         self.stage = stage
         if self.config.cls_token[self.stage]:
-            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.embed_dim[-1]))
+            self.cls_token = nn.Parameter(
+                nn.init.trunc_normal_(
+                    torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=config.initializer_range
+                )
+            )
 
         self.embedding = CvtEmbeddings(
             patch_size=config.patch_sizes[self.stage],
@@ -547,9 +551,7 @@ class CvtPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment