Unverified commit 0c7f93f5, authored by fxmarty and committed by GitHub
Browse files

Fix nn.init.trunc_normal_ call on torch.float16 data (#21789)

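Fix the nn.init.trunc_normal_ call on half-precision (torch.float16) data: the position embeddings and the CLS token are upcast to torch.float32 for initialization, and the result is cast back to the parameter's original dtype.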
parent ebf84f07
......@@ -449,17 +449,17 @@ class ViTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
nn.init.trunc_normal_(
module.position_embeddings,
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
)
).to(module.position_embeddings.dtype)
nn.init.trunc_normal_(
module.cls_token,
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
)
).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
if isinstance(module, ViTEncoder):
......
......@@ -474,17 +474,17 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTHybridEmbeddings):
nn.init.trunc_normal_(
module.position_embeddings,
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
)
).to(module.position_embeddings.dtype)
nn.init.trunc_normal_(
module.cls_token,
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
)
).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
if isinstance(module, ViTHybridEncoder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment