Unverified Commit 0b86e330 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`CVT`] Fix module initialization issue (#21193)

fix cvt init
parent b9403e95
...@@ -451,11 +451,7 @@ class CvtStage(nn.Module): ...@@ -451,11 +451,7 @@ class CvtStage(nn.Module):
self.config = config self.config = config
self.stage = stage self.stage = stage
if self.config.cls_token[self.stage]: if self.config.cls_token[self.stage]:
self.cls_token = nn.Parameter( self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1]))
nn.init.trunc_normal_(
torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=config.initializer_range
)
)
self.embedding = CvtEmbeddings( self.embedding = CvtEmbeddings(
patch_size=config.patch_sizes[self.stage], patch_size=config.patch_sizes[self.stage],
...@@ -557,6 +553,11 @@ class CvtPreTrainedModel(PreTrainedModel): ...@@ -557,6 +553,11 @@ class CvtPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, CvtStage):
if self.config.cls_token[module.stage]:
module.cls_token.data = nn.init.trunc_normal_(
torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range
)
CVT_START_DOCSTRING = r""" CVT_START_DOCSTRING = r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment